From b0bf49137bde7066747d71a3d3fb6dbc46556ff8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 6 Dec 2024 16:46:35 +0100 Subject: [PATCH 001/111] Add module 'new_base' --- src/lib.rs | 1 + src/new_base/mod.rs | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 src/new_base/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 0d0a4a2ba..e9aef12b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -193,6 +193,7 @@ extern crate core; pub mod base; pub mod dep; pub mod net; +pub mod new_base; pub mod rdata; pub mod resolv; pub mod sign; diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs new file mode 100644 index 000000000..4257c2712 --- /dev/null +++ b/src/new_base/mod.rs @@ -0,0 +1,5 @@ +//! Basic DNS. +//! +//! This module provides the essential types and functionality for working +//! with DNS. Most importantly, it provides functionality for parsing and +//! building DNS messages on the wire. From ea600fd3f5e666c2812a29c15011f77190b76042 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 9 Dec 2024 17:36:41 +0100 Subject: [PATCH 002/111] [new_base] Add module 'name' --- src/new_base/mod.rs | 2 ++ src/new_base/name/mod.rs | 15 +++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 src/new_base/name/mod.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 4257c2712..c29a0b49f 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -3,3 +3,5 @@ //! This module provides the essential types and functionality for working //! with DNS. Most importantly, it provides functionality for parsing and //! building DNS messages on the wire. + +pub mod name; diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs new file mode 100644 index 000000000..e288f9ee4 --- /dev/null +++ b/src/new_base/name/mod.rs @@ -0,0 +1,15 @@ +//! Domain names. +//! +//! Domain names are a core concept of DNS. The whole system is essentially +//! just a mapping from domain names to arbitrary information. This module +//! provides types and essential functionality for working with them. +//! +//! A domain name is a sequence of labels, separated by ASCII periods (`.`). +//! For example, `example.org.` contains three labels: `example`, `org`, and +//! `` (the root label). Outside DNS-specific code, the root label (and its +//! separator) are almost always omitted, but keep them in mind here. +//! +//! Domain names form a hierarchy, where `b.a` is the "parent" of `.c.b.a`. +//! The owner of `example.org` is thus responsible for _every_ domain ending +//! with the `.example.org` suffix. The reverse order in which this hierarchy +//! is expressed can sometimes be confusing. From 6fb0957e79691dc30608d6204640e3b2c8bc8f53 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 9 Dec 2024 20:01:52 +0100 Subject: [PATCH 003/111] [new_base/name] Define labels --- src/new_base/name/label.rs | 181 +++++++++++++++++++++++++++++++++++++ src/new_base/name/mod.rs | 3 + 2 files changed, 184 insertions(+) create mode 100644 src/new_base/name/label.rs diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs new file mode 100644 index 000000000..d7d83b6a1 --- /dev/null +++ b/src/new_base/name/label.rs @@ -0,0 +1,181 @@ +//! Labels in domain names. + +//----------- Label ---------------------------------------------------------- + +use core::{ + cmp::Ordering, + fmt, + hash::{Hash, Hasher}, +}; + +/// A label in a domain name. +/// +/// A label contains up to 63 bytes of arbitrary data. +#[repr(transparent)] +pub struct Label([u8]); + +//--- Associated Constants + +impl Label { + /// The root label. + pub const ROOT: &'static Self = { + // SAFETY: All slices of 63 bytes or less are valid. + unsafe { Self::from_bytes_unchecked(b"") } + }; + + /// The wildcard label. + pub const WILDCARD: &'static Self = { + // SAFETY: All slices of 63 bytes or less are valid. + unsafe { Self::from_bytes_unchecked(b"*") } + }; +} + +//--- Construction + +impl Label { + /// Assume a byte slice is a valid label. + /// + /// # Safety + /// + /// The byte slice must have length 63 or less. + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'Label' is 'repr(transparent)' to '[u8]'. + unsafe { core::mem::transmute(bytes) } + } +} + +//--- Inspection + +impl Label { + /// The length of this label, in bytes. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.is_empty() + } + + /// Whether this is a wildcard label. + pub const fn is_wildcard(&self) -> bool { + // NOTE: '==' for byte slices is not 'const'. + self.0.len() == 1 && self.0[0] == b'*' + } + + /// The bytes making up this label. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +//--- Access to the underlying bytes + +impl AsRef<[u8]> for Label { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl<'a> From<&'a Label> for &'a [u8] { + fn from(value: &'a Label) -> Self { + &value.0 + } +} + +//--- Comparison + +impl PartialEq for Label { + /// Compare two labels for equality. + /// + /// Labels are compared ASCII-case-insensitively. + fn eq(&self, other: &Self) -> bool { + let this = self.as_bytes().iter().map(u8::to_ascii_lowercase); + let that = other.as_bytes().iter().map(u8::to_ascii_lowercase); + this.eq(that) + } +} + +impl Eq for Label {} + +//--- Ordering + +impl PartialOrd for Label { + /// Determine the order between labels. + /// + /// Any uppercase ASCII characters in the labels are treated as if they + /// were lowercase. The first unequal byte between two labels determines + /// its ordering: the label with the smaller byte value is the lesser. If + /// two labels have all the same bytes, the shorter label is lesser; if + /// they are the same length, they are equal. + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Label { + /// Determine the order between labels. + /// + /// Any uppercase ASCII characters in the labels are treated as if they + /// were lowercase. The first unequal byte between two labels determines + /// its ordering: the label with the smaller byte value is the lesser. If + /// two labels have all the same bytes, the shorter label is lesser; if + /// they are the same length, they are equal. + fn cmp(&self, other: &Self) -> Ordering { + let this = self.as_bytes().iter().map(u8::to_ascii_lowercase); + let that = other.as_bytes().iter().map(u8::to_ascii_lowercase); + this.cmp(that) + } +} + +//--- Hashing + +impl Hash for Label { + /// Hash this label. + /// + /// All uppercase ASCII characters are lowercased beforehand. This way, + /// the hash of a label is case-independent, consistent with how labels + /// are compared and ordered. + /// + /// The label is hashed as if it were a name containing a single label -- + /// the length octet is thus included. This makes the hashing consistent + /// between names and tuples (not slices!) of labels. + fn hash(&self, state: &mut H) { + state.write_u8(self.len() as u8); + for &byte in self.as_bytes() { + state.write_u8(byte.to_ascii_lowercase()) + } + } +} + +//--- Formatting + +impl fmt::Display for Label { + /// Print a label. + /// + /// The label is printed in the conventional zone file format, with bytes + /// outside printable ASCII formatted as `\\DDD` (a backslash followed by + /// three zero-padded decimal digits), and `.` and `\\` simply escaped by + /// a backslash. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.as_bytes().iter().try_for_each(|&byte| { + if b".\\".contains(&byte) { + write!(f, "\\{}", byte as char) + } else if byte.is_ascii_graphic() { + write!(f, "{}", byte as char) + } else { + write!(f, "\\{:03}", byte) + } + }) + } +} + +impl fmt::Debug for Label { + /// Print a label for debugging purposes. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Label") + .field(&format_args!("{}", self)) + .finish() + } +} diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index e288f9ee4..1cc63e1cd 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -13,3 +13,6 @@ //! The owner of `example.org` is thus responsible for _every_ domain ending //! with the `.example.org` suffix. The reverse order in which this hierarchy //! is expressed can sometimes be confusing. + +mod label; +pub use label::Label; From 66c4d198d911f25800e9bc2bf1c97b85943b1e42 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:08:43 +0100 Subject: [PATCH 004/111] [new_base] Add module 'message' --- Cargo.lock | 26 +++- Cargo.toml | 6 + src/new_base/message.rs | 284 ++++++++++++++++++++++++++++++++++++++++ src/new_base/mod.rs | 1 + 4 files changed, 315 insertions(+), 2 deletions(-) create mode 100644 src/new_base/message.rs diff --git a/Cargo.lock b/Cargo.lock index 7f844fa92..7506702e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -278,6 +278,8 @@ dependencies = [ "tracing", "tracing-subscriber", "webpki-roots", + "zerocopy 0.8.13", + "zerocopy-derive 0.8.13", ] [[package]] @@ -797,7 +799,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -1690,7 +1692,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67914ab451f3bfd2e69e5e9d2ef3858484e7074d63f204fd166ec391b54de21d" +dependencies = [ + "zerocopy-derive 0.8.13", ] [[package]] @@ -1704,6 +1715,17 @@ dependencies = [ "syn", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7988d73a4303ca289df03316bc490e934accf371af6bc745393cf3c2c5c4f25d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 5fb61052e..0072d61fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,12 @@ tokio-stream = { version = "0.1.1", optional = true } tracing = { version = "0.1.40", optional = true } tracing-subscriber = { version = "0.3.18", optional = true, features = ["env-filter"] } +# 'zerocopy' provides simple derives for converting types to and from byte +# representations, along with network-endian integer primitives. These are +# used to define simple elements of DNS messages and their serialization. +zerocopy = "0.8" +zerocopy-derive = "0.8" + [features] default = ["std", "rand"] diff --git a/src/new_base/message.rs b/src/new_base/message.rs new file mode 100644 index 000000000..c07d605fa --- /dev/null +++ b/src/new_base/message.rs @@ -0,0 +1,284 @@ +//! DNS message headers. + +use core::fmt; + +use zerocopy::network_endian::U16; +use zerocopy_derive::*; + +//----------- Message -------------------------------------------------------- + +/// A DNS message. +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct Message { + /// The message header. + pub header: Header, + + /// The message contents. + pub contents: [u8], +} + +//----------- Header --------------------------------------------------------- + +/// A DNS message header. +#[derive( + Copy, + Clone, + Debug, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(C)] +pub struct Header { + /// A unique identifier for the message. + pub id: U16, + + /// Properties of the message. + pub flags: HeaderFlags, + + /// Counts of objects in the message. + pub counts: SectionCounts, +} + +//--- Formatting + +impl fmt::Display for Header { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} of ID {:04X} ({})", + self.flags, + self.id.get(), + self.counts + ) + } +} + +//----------- HeaderFlags ---------------------------------------------------- + +/// DNS message header flags. +#[derive( + Copy, + Clone, + Default, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct HeaderFlags { + inner: U16, +} + +//--- Interaction + +impl HeaderFlags { + /// Get the specified flag bit. + fn get_flag(&self, pos: u32) -> bool { + self.inner.get() & (1 << pos) != 0 + } + + /// Set the specified flag bit. + fn set_flag(mut self, pos: u32, value: bool) -> Self { + self.inner &= !(1 << pos); + self.inner |= (value as u16) << pos; + self + } + + /// The raw flags bits. + pub fn bits(&self) -> u16 { + self.inner.get() + } + + /// Whether this is a query. + pub fn is_query(&self) -> bool { + !self.get_flag(15) + } + + /// Whether this is a response. + pub fn is_response(&self) -> bool { + self.get_flag(15) + } + + /// The operation code. + pub fn opcode(&self) -> u8 { + (self.inner.get() >> 11) as u8 & 0xF + } + + /// The response code. + pub fn rcode(&self) -> u8 { + self.inner.get() as u8 & 0xF + } + + /// Construct a query. + pub fn query(mut self, opcode: u8) -> Self { + assert!(opcode < 16); + self.inner &= !(0xF << 11); + self.inner |= (opcode as u16) << 11; + self.set_flag(15, false) + } + + /// Construct a response. + pub fn respond(mut self, rcode: u8) -> Self { + assert!(rcode < 16); + self.inner &= !0xF; + self.inner |= rcode as u16; + self.set_flag(15, true) + } + + /// Whether this is an authoritative answer. + pub fn is_authoritative(&self) -> bool { + self.get_flag(10) + } + + /// Mark this as an authoritative answer. + pub fn set_authoritative(self, value: bool) -> Self { + self.set_flag(10, value) + } + + /// Whether this message is truncated. + pub fn is_truncated(&self) -> bool { + self.get_flag(9) + } + + /// Mark this message as truncated. + pub fn set_truncated(self, value: bool) -> Self { + self.set_flag(9, value) + } + + /// Whether the server should query recursively. + pub fn should_recurse(&self) -> bool { + self.get_flag(8) + } + + /// Direct the server to query recursively. + pub fn request_recursion(self, value: bool) -> Self { + self.set_flag(8, value) + } + + /// Whether the server supports recursion. + pub fn can_recurse(&self) -> bool { + self.get_flag(7) + } + + /// Indicate support for recursive queries. + pub fn support_recursion(self, value: bool) -> Self { + self.set_flag(7, value) + } +} + +//--- Formatting + +impl fmt::Debug for HeaderFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HeaderFlags") + .field("is_response (qr)", &self.is_response()) + .field("opcode", &self.opcode()) + .field("is_authoritative (aa)", &self.is_authoritative()) + .field("is_truncated (tc)", &self.is_truncated()) + .field("should_recurse (rd)", &self.should_recurse()) + .field("can_recurse (ra)", &self.can_recurse()) + .field("rcode", &self.rcode()) + .field("bits", &self.inner.get()) + .finish() + } +} + +impl fmt::Display for HeaderFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_query() { + if self.should_recurse() { + f.write_str("recursive ")?; + } + write!(f, "query (opcode {})", self.opcode())?; + } else { + if self.is_authoritative() { + f.write_str("authoritative ")?; + } + if self.should_recurse() && self.can_recurse() { + f.write_str("recursive ")?; + } + write!(f, "response (rcode {})", self.rcode())?; + } + + if self.is_truncated() { + f.write_str(" (message truncated)")?; + } + + Ok(()) + } +} + +//----------- SectionCounts -------------------------------------------------- + +/// Counts of objects in a DNS message. +#[derive( + Copy, + Clone, + Debug, + Default, + PartialEq, + Eq, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(C)] +pub struct SectionCounts { + /// The number of questions in the message. + pub questions: U16, + + /// The number of answer records in the message. + pub answers: U16, + + /// The number of name server records in the message. + pub authorities: U16, + + /// The number of additional records in the message. + pub additional: U16, +} + +//--- Formatting + +impl fmt::Display for SectionCounts { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut some = false; + + for (num, single, many) in [ + (self.questions.get(), "question", "questions"), + (self.answers.get(), "answer", "answers"), + (self.authorities.get(), "authority", "authorities"), + (self.additional.get(), "additional", "additional"), + ] { + // Add a comma if we have printed something before. + if some && num > 0 { + f.write_str(", ")?; + } + + // Print a count of this section. + match num { + 0 => {} + 1 => write!(f, "1 {single}")?, + n => write!(f, "{n} {many}")?, + } + + some |= num > 0; + } + + if !some { + f.write_str("empty")?; + } + + Ok(()) + } +} diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index c29a0b49f..368416354 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -4,4 +4,5 @@ //! with DNS. Most importantly, it provides functionality for parsing and //! building DNS messages on the wire. +pub mod message; pub mod name; From 48051b46abb396ac089b74854cff2a3836bcb481 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:09:49 +0100 Subject: [PATCH 005/111] [new_base/name/label] Use 'zerocopy' --- src/new_base/name/label.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index d7d83b6a1..296a167fe 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -1,16 +1,19 @@ //! Labels in domain names. -//----------- Label ---------------------------------------------------------- - use core::{ cmp::Ordering, fmt, hash::{Hash, Hasher}, }; +use zerocopy_derive::*; + +//----------- Label ---------------------------------------------------------- + /// A label in a domain name. /// /// A label contains up to 63 bytes of arbitrary data. +#[derive(IntoBytes, Immutable, Unaligned)] #[repr(transparent)] pub struct Label([u8]); From be78a8f201e375645eef3ddf54b84a8683aaf60a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:26:00 +0100 Subject: [PATCH 006/111] [new_base] Add module 'parse' --- src/new_base/mod.rs | 6 ++++- src/new_base/parse.rs | 51 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 src/new_base/parse.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 368416354..8a60f64c9 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -4,5 +4,9 @@ //! with DNS. Most importantly, it provides functionality for parsing and //! building DNS messages on the wire. -pub mod message; +mod message; +pub use message::{Header, HeaderFlags, Message, SectionCounts}; + pub mod name; + +pub mod parse; diff --git a/src/new_base/parse.rs b/src/new_base/parse.rs new file mode 100644 index 000000000..2a1697fb2 --- /dev/null +++ b/src/new_base/parse.rs @@ -0,0 +1,51 @@ +//! Parsing DNS messages from the wire format. + +use core::fmt; + +//----------- Low-level parsing traits --------------------------------------- + +/// Parsing from the start of a byte string. +pub trait SplitFrom<'a>: Sized { + /// Parse a value of [`Self`] from the start of the byte string. + /// + /// If parsing is successful, the parsed value and the rest of the string + /// are returned. Otherwise, a [`ParseError`] is returned. + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; +} + +/// Parsing from a byte string. +pub trait ParseFrom<'a>: Sized { + /// Parse a value of [`Self`] from the given byte string. + /// + /// If parsing is successful, the parsed value is returned. Otherwise, a + /// [`ParseError`] is returned. + fn parse_from(bytes: &'a [u8]) -> Result; +} + +//----------- ParseError ----------------------------------------------------- + +/// A DNS parsing error. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ParseError; + +//--- Formatting + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("DNS data could not be parsed from the wire format") + } +} + +//--- Conversion from 'zerocopy' errors + +impl From> for ParseError { + fn from(_: zerocopy::ConvertError) -> Self { + Self + } +} + +impl From> for ParseError { + fn from(_: zerocopy::SizeError) -> Self { + Self + } +} From 7a1a847717912a30a43a22b22c3ffd96799f12b3 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:26:16 +0100 Subject: [PATCH 007/111] [new_base/name] Add module 'parsed' --- src/new_base/name/mod.rs | 3 + src/new_base/name/parsed.rs | 131 ++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 src/new_base/name/parsed.rs diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index 1cc63e1cd..9ee96824a 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -16,3 +16,6 @@ mod label; pub use label::Label; + +mod parsed; +pub use parsed::ParsedName; diff --git a/src/new_base/name/parsed.rs b/src/new_base/name/parsed.rs new file mode 100644 index 000000000..abf592e5d --- /dev/null +++ b/src/new_base/name/parsed.rs @@ -0,0 +1,131 @@ +//! Domain names encoded in DNS messages. + +use zerocopy_derive::*; + +use crate::new_base::parse::{ParseError, ParseFrom, SplitFrom}; + +//----------- ParsedName ----------------------------------------------------- + +/// A domain name in a DNS message. +#[derive(Debug, IntoBytes, Immutable, Unaligned)] +#[repr(transparent)] +pub struct ParsedName([u8]); + +//--- Constants + +impl ParsedName { + /// The maximum size of a parsed domain name in the wire format. + /// + /// This can occur if a compression pointer is used to point to a root + /// name, even though such a representation is longer than copying the + /// root label into the name. + pub const MAX_SIZE: usize = 256; + + /// The root name. + pub const ROOT: &'static Self = { + // SAFETY: A root label is the shortest valid name. + unsafe { Self::from_bytes_unchecked(&[0u8]) } + }; +} + +//--- Construction + +impl ParsedName { + /// Assume a byte string is a valid [`ParsedName`]. + /// + /// # Safety + /// + /// The byte string must be correctly encoded in the wire format, and + /// within the size restriction (256 bytes or fewer). It must end with a + /// root label or a compression pointer. + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'ParsedName' is 'repr(transparent)' to '[u8]', so casting a + // '[u8]' into a 'ParsedName' is sound. + core::mem::transmute(bytes) + } +} + +//--- Inspection + +impl ParsedName { + /// The size of this name in the wire format. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.len() == 1 + } + + /// Whether this is a compression pointer. + pub const fn is_pointer(&self) -> bool { + self.0.len() == 2 + } + + /// The wire format representation of the name. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +//--- Parsing + +impl<'a> SplitFrom<'a> for &'a ParsedName { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + // Iterate through the labels in the name. + let mut index = 0usize; + loop { + if index >= ParsedName::MAX_SIZE || index >= bytes.len() { + return Err(ParseError); + } + let length = bytes[index]; + if length == 0 { + // This was the root label. + index += 1; + break; + } else if length < 0x40 { + // This was the length of the label. + index += 1 + length as usize; + } else if length >= 0xC0 { + // This was a compression pointer. + if index + 1 >= bytes.len() { + return Err(ParseError); + } + index += 2; + break; + } else { + // This was a reserved or deprecated label type. + return Err(ParseError); + } + } + + let (name, bytes) = bytes.split_at(index); + // SAFETY: 'bytes' has been confirmed to be correctly encoded. + Ok((unsafe { ParsedName::from_bytes_unchecked(name) }, bytes)) + } +} + +impl<'a> ParseFrom<'a> for &'a ParsedName { + fn parse_from(bytes: &'a [u8]) -> Result { + Self::split_from(bytes).and_then(|(name, rest)| { + rest.is_empty().then_some(name).ok_or(ParseError) + }) + } +} + +//--- Conversion to and from bytes + +impl AsRef<[u8]> for ParsedName { + /// The bytes in the name in the wire format. + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl<'a> From<&'a ParsedName> for &'a [u8] { + fn from(name: &'a ParsedName) -> Self { + name.as_bytes() + } +} From ed95534b0ea589a8ec965443a87c066522fd6113 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:31:52 +0100 Subject: [PATCH 008/111] [new_base] Add module 'question' --- src/new_base/mod.rs | 3 ++ src/new_base/question.rs | 104 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 src/new_base/question.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 8a60f64c9..499187fb9 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -9,4 +9,7 @@ pub use message::{Header, HeaderFlags, Message, SectionCounts}; pub mod name; +mod question; +pub use question::{QClass, QType, Question}; + pub mod parse; diff --git a/src/new_base/question.rs b/src/new_base/question.rs new file mode 100644 index 000000000..16e388c1c --- /dev/null +++ b/src/new_base/question.rs @@ -0,0 +1,104 @@ +//! DNS questions. + +use zerocopy::{network_endian::U16, FromBytes}; +use zerocopy_derive::*; + +use super::{ + name::ParsedName, + parse::{ParseError, ParseFrom, SplitFrom}, +}; + +//----------- Question ------------------------------------------------------- + +/// A DNS question. +pub struct Question<'a> { + /// The domain name being requested. + pub qname: &'a ParsedName, + + /// The type of the requested records. + pub qtype: QType, + + /// The class of the requested records. + pub qclass: QClass, +} + +//--- Construction + +impl<'a> Question<'a> { + /// Construct a new [`Question`]. + pub fn new(qname: &'a ParsedName, qtype: QType, qclass: QClass) -> Self { + Self { + qname, + qtype, + qclass, + } + } +} + +//--- Parsing + +impl<'a> SplitFrom<'a> for Question<'a> { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (qname, rest) = <&ParsedName>::split_from(bytes)?; + let (qtype, rest) = QType::read_from_prefix(rest)?; + let (qclass, rest) = QClass::read_from_prefix(rest)?; + Ok((Self::new(qname, qtype, qclass), rest)) + } +} + +impl<'a> ParseFrom<'a> for Question<'a> { + fn parse_from(bytes: &'a [u8]) -> Result { + let (qname, rest) = <&ParsedName>::split_from(bytes)?; + let (qtype, rest) = QType::read_from_prefix(rest)?; + let qclass = QClass::read_from_bytes(rest)?; + Ok(Self::new(qname, qtype, qclass)) + } +} + +//----------- QType ---------------------------------------------------------- + +/// The type of a question. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct QType { + /// The type code. + pub code: U16, +} + +//----------- QClass --------------------------------------------------------- + +/// The class of a question. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct QClass { + /// The class code. + pub code: U16, +} From 37bc7d2145791de17d535e4b70812c76eb086fea Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 11 Dec 2024 15:59:23 +0100 Subject: [PATCH 009/111] Add module 'record' --- src/new_base/mod.rs | 3 + src/new_base/record/mod.rs | 155 +++++++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 src/new_base/record/mod.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 499187fb9..f7baced2b 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -12,4 +12,7 @@ pub mod name; mod question; pub use question::{QClass, QType, Question}; +pub mod record; +pub use record::Record; + pub mod parse; diff --git a/src/new_base/record/mod.rs b/src/new_base/record/mod.rs new file mode 100644 index 000000000..fc348b710 --- /dev/null +++ b/src/new_base/record/mod.rs @@ -0,0 +1,155 @@ +//! DNS records. + +use zerocopy::{ + network_endian::{U16, U32}, + FromBytes, +}; +use zerocopy_derive::*; + +use super::{ + name::ParsedName, + parse::{ParseError, ParseFrom, SplitFrom}, +}; + +//----------- Record --------------------------------------------------------- + +/// An unparsed DNS record. +pub struct Record<'a> { + /// The name of the record. + pub rname: &'a ParsedName, + + /// The type of the record. + pub rtype: RType, + + /// The class of the record. + pub rclass: RClass, + + /// How long the record is reliable for. + pub ttl: TTL, + + /// Unparsed record data. + pub rdata: &'a [u8], +} + +//--- Construction + +impl<'a> Record<'a> { + /// Construct a new [`Record`]. + pub fn new( + rname: &'a ParsedName, + rtype: RType, + rclass: RClass, + ttl: TTL, + rdata: &'a [u8], + ) -> Self { + Self { + rname, + rtype, + rclass, + ttl, + rdata, + } + } +} + +//--- Parsing + +impl<'a> SplitFrom<'a> for Record<'a> { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (rname, rest) = <&ParsedName>::split_from(bytes)?; + let (rtype, rest) = RType::read_from_prefix(rest)?; + let (rclass, rest) = RClass::read_from_prefix(rest)?; + let (ttl, rest) = TTL::read_from_prefix(rest)?; + let (size, rest) = U16::read_from_prefix(rest)?; + let size = size.get() as usize; + let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + + Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) + } +} + +impl<'a> ParseFrom<'a> for Record<'a> { + fn parse_from(bytes: &'a [u8]) -> Result { + let (rname, rest) = <&ParsedName>::split_from(bytes)?; + let (rtype, rest) = RType::read_from_prefix(rest)?; + let (rclass, rest) = RClass::read_from_prefix(rest)?; + let (ttl, rest) = TTL::read_from_prefix(rest)?; + let (size, rest) = U16::read_from_prefix(rest)?; + let size = size.get() as usize; + let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; + + Ok(Self::new(rname, rtype, rclass, ttl, rdata)) + } +} + +//----------- RType ---------------------------------------------------------- + +/// The type of a record. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct RType { + /// The type code. + pub code: U16, +} + +//----------- RClass --------------------------------------------------------- + +/// The class of a record. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct RClass { + /// The class code. + pub code: U16, +} + +//----------- TTL ------------------------------------------------------------ + +/// How long a record can be cached. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct TTL { + /// The underlying value. + pub value: U32, +} From 26653c391c8c936ccfa7a06fe083e8a7144643d5 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 11 Dec 2024 15:59:38 +0100 Subject: [PATCH 010/111] [new_base] Add high-level parsing traits --- src/new_base/parse/message.rs | 49 ++++++++ src/new_base/{parse.rs => parse/mod.rs} | 34 ++++++ src/new_base/parse/question.rs | 148 ++++++++++++++++++++++++ src/new_base/parse/record.rs | 148 ++++++++++++++++++++++++ 4 files changed, 379 insertions(+) create mode 100644 src/new_base/parse/message.rs rename src/new_base/{parse.rs => parse/mod.rs} (62%) create mode 100644 src/new_base/parse/question.rs create mode 100644 src/new_base/parse/record.rs diff --git a/src/new_base/parse/message.rs b/src/new_base/parse/message.rs new file mode 100644 index 000000000..eaea9845d --- /dev/null +++ b/src/new_base/parse/message.rs @@ -0,0 +1,49 @@ +//! Parsing DNS messages. + +use core::ops::ControlFlow; + +use crate::new_base::{Header, Question, Record}; + +/// A type that can be constructed by parsing a DNS message. +pub trait ParseMessage<'a>: Sized { + /// The type of visitors for incrementally building the output. + type Visitor: VisitMessagePart<'a>; + + /// The type of errors from converting a visitor into [`Self`]. + // TODO: Just use 'Visitor::Error'? + type Error; + + /// Construct a visitor, providing the message header. + fn make_visitor(header: &'a Header) + -> Result; + + /// Convert a visitor back to this type. + fn from_visitor(visitor: Self::Visitor) -> Result; +} + +/// A type that can visit the components of a DNS message. +pub trait VisitMessagePart<'a> { + /// The type of errors produced by visits. + type Error; + + /// Visit a component of the message. + fn visit( + &mut self, + component: MessagePart<'a>, + ) -> Result, Self::Error>; +} + +/// A component of a DNS message. +pub enum MessagePart<'a> { + /// A question. + Question(Question<'a>), + + /// An answer record. + Answer(Record<'a>), + + /// An authority record. + Authority(Record<'a>), + + /// An additional record. + Additional(Record<'a>), +} diff --git a/src/new_base/parse.rs b/src/new_base/parse/mod.rs similarity index 62% rename from src/new_base/parse.rs rename to src/new_base/parse/mod.rs index 2a1697fb2..a273717be 100644 --- a/src/new_base/parse.rs +++ b/src/new_base/parse/mod.rs @@ -2,6 +2,17 @@ use core::fmt; +use zerocopy::{FromBytes, Immutable, KnownLayout}; + +mod message; +pub use message::{MessagePart, ParseMessage, VisitMessagePart}; + +mod question; +pub use question::{ParseQuestion, ParseQuestions, VisitQuestion}; + +mod record; +pub use record::{ParseRecord, ParseRecords, VisitRecord}; + //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. @@ -22,6 +33,29 @@ pub trait ParseFrom<'a>: Sized { fn parse_from(bytes: &'a [u8]) -> Result; } +//--- Carrying over 'zerocopy' traits + +// NOTE: We can't carry over 'read_from_prefix' because the trait impls would +// conflict. We kept 'ref_from_prefix' since it's more general. + +impl<'a, T: ?Sized> SplitFrom<'a> for &'a T +where + T: FromBytes + KnownLayout + Immutable, +{ + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + T::ref_from_prefix(bytes).map_err(|_| ParseError) + } +} + +impl<'a, T: ?Sized> ParseFrom<'a> for &'a T +where + T: FromBytes + KnownLayout + Immutable, +{ + fn parse_from(bytes: &'a [u8]) -> Result { + T::ref_from_bytes(bytes).map_err(|_| ParseError) + } +} + //----------- ParseError ----------------------------------------------------- /// A DNS parsing error. diff --git a/src/new_base/parse/question.rs b/src/new_base/parse/question.rs new file mode 100644 index 000000000..e08ea6283 --- /dev/null +++ b/src/new_base/parse/question.rs @@ -0,0 +1,148 @@ +//! Parsing DNS questions. + +use core::{convert::Infallible, ops::ControlFlow}; + +#[cfg(feature = "std")] +use std::boxed::Box; +#[cfg(feature = "std")] +use std::vec::Vec; + +use crate::new_base::Question; + +//----------- Trait definitions ---------------------------------------------- + +/// A type that can be constructed by parsing exactly one DNS question. +pub trait ParseQuestion<'a>: Sized { + /// The type of parse errors. + // TODO: Remove entirely? + type Error; + + /// Parse the given DNS question. + fn parse_question( + question: Question<'a>, + ) -> Result, Self::Error>; +} + +/// A type that can be constructed by parsing zero or more DNS questions. +pub trait ParseQuestions<'a>: Sized { + /// The type of visitors for incrementally building the output. + type Visitor: Default + VisitQuestion<'a>; + + /// The type of errors from converting a visitor into [`Self`]. + // TODO: Just use 'Visitor::Error'? Or remove entirely? + type Error; + + /// Convert a visitor back to this type. + fn from_visitor(visitor: Self::Visitor) -> Result; +} + +/// A type that can visit DNS questions. +pub trait VisitQuestion<'a> { + /// The type of errors produced by visits. + type Error; + + /// Visit a question. + fn visit_question( + &mut self, + question: Question<'a>, + ) -> Result, Self::Error>; +} + +//----------- Trait implementations ------------------------------------------ + +impl<'a> ParseQuestion<'a> for Question<'a> { + type Error = Infallible; + + fn parse_question( + question: Question<'a>, + ) -> Result, Self::Error> { + Ok(ControlFlow::Break(question)) + } +} + +//--- Impls for 'Option' + +impl<'a, T: ParseQuestion<'a>> ParseQuestion<'a> for Option { + type Error = T::Error; + + fn parse_question( + question: Question<'a>, + ) -> Result, Self::Error> { + Ok(match T::parse_question(question)? { + ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Option { + type Visitor = Option; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor) + } +} + +impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Option { + type Error = T::Error; + + fn visit_question( + &mut self, + question: Question<'a>, + ) -> Result, Self::Error> { + if self.is_some() { + return Ok(ControlFlow::Continue(())); + } + + Ok(match T::parse_question(question)? { + ControlFlow::Break(elem) => { + *self = Some(elem); + ControlFlow::Break(()) + } + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +//--- Impls for 'Vec' + +#[cfg(feature = "std")] +impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Vec { + type Visitor = Vec; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor) + } +} + +#[cfg(feature = "std")] +impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Vec { + type Error = T::Error; + + fn visit_question( + &mut self, + question: Question<'a>, + ) -> Result, Self::Error> { + Ok(match T::parse_question(question)? { + ControlFlow::Break(elem) => { + self.push(elem); + ControlFlow::Break(()) + } + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +//--- Impls for 'Box<[T]>' + +#[cfg(feature = "std")] +impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Box<[T]> { + type Visitor = Vec; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor.into_boxed_slice()) + } +} diff --git a/src/new_base/parse/record.rs b/src/new_base/parse/record.rs new file mode 100644 index 000000000..c93f2f8d1 --- /dev/null +++ b/src/new_base/parse/record.rs @@ -0,0 +1,148 @@ +//! Parsing DNS records. + +use core::{convert::Infallible, ops::ControlFlow}; + +#[cfg(feature = "std")] +use std::boxed::Box; +#[cfg(feature = "std")] +use std::vec::Vec; + +use crate::new_base::Record; + +//----------- Trait definitions ---------------------------------------------- + +/// A type that can be constructed by parsing exactly one DNS record. +pub trait ParseRecord<'a>: Sized { + /// The type of parse errors. + // TODO: Remove entirely? + type Error; + + /// Parse the given DNS record. + fn parse_record( + record: Record<'a>, + ) -> Result, Self::Error>; +} + +/// A type that can be constructed by parsing zero or more DNS records. +pub trait ParseRecords<'a>: Sized { + /// The type of visitors for incrementally building the output. + type Visitor: Default + VisitRecord<'a>; + + /// The type of errors from converting a visitor into [`Self`]. + // TODO: Just use 'Visitor::Error'? Or remove entirely? + type Error; + + /// Convert a visitor back to this type. + fn from_visitor(visitor: Self::Visitor) -> Result; +} + +/// A type that can visit DNS records. +pub trait VisitRecord<'a> { + /// The type of errors produced by visits. + type Error; + + /// Visit a record. + fn visit_record( + &mut self, + record: Record<'a>, + ) -> Result, Self::Error>; +} + +//----------- Trait implementations ------------------------------------------ + +impl<'a> ParseRecord<'a> for Record<'a> { + type Error = Infallible; + + fn parse_record( + record: Record<'a>, + ) -> Result, Self::Error> { + Ok(ControlFlow::Break(record)) + } +} + +//--- Impls for 'Option' + +impl<'a, T: ParseRecord<'a>> ParseRecord<'a> for Option { + type Error = T::Error; + + fn parse_record( + record: Record<'a>, + ) -> Result, Self::Error> { + Ok(match T::parse_record(record)? { + ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Option { + type Visitor = Option; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor) + } +} + +impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Option { + type Error = T::Error; + + fn visit_record( + &mut self, + record: Record<'a>, + ) -> Result, Self::Error> { + if self.is_some() { + return Ok(ControlFlow::Continue(())); + } + + Ok(match T::parse_record(record)? { + ControlFlow::Break(elem) => { + *self = Some(elem); + ControlFlow::Break(()) + } + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +//--- Impls for 'Vec' + +#[cfg(feature = "std")] +impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Vec { + type Visitor = Vec; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor) + } +} + +#[cfg(feature = "std")] +impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Vec { + type Error = T::Error; + + fn visit_record( + &mut self, + record: Record<'a>, + ) -> Result, Self::Error> { + Ok(match T::parse_record(record)? { + ControlFlow::Break(elem) => { + self.push(elem); + ControlFlow::Break(()) + } + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +//--- Impls for 'Box<[T]>' + +#[cfg(feature = "std")] +impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Box<[T]> { + type Visitor = Vec; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor.into_boxed_slice()) + } +} From e1c701ff288918dd6fb0ede61594d36e8fa4632a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 11 Dec 2024 23:15:55 +0100 Subject: [PATCH 011/111] [new_base/name] Add module 'reversed' --- src/new_base/name/label.rs | 69 +++++++++++ src/new_base/name/mod.rs | 6 +- src/new_base/name/reversed.rs | 214 ++++++++++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 3 deletions(-) create mode 100644 src/new_base/name/reversed.rs diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 296a167fe..48420df3a 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -4,10 +4,13 @@ use core::{ cmp::Ordering, fmt, hash::{Hash, Hasher}, + iter::FusedIterator, }; use zerocopy_derive::*; +use crate::new_base::parse::{ParseError, SplitFrom}; + //----------- Label ---------------------------------------------------------- /// A label in a domain name. @@ -47,6 +50,21 @@ impl Label { } } +//--- Parsing + +impl<'a> SplitFrom<'a> for &'a Label { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (&size, rest) = bytes.split_first().ok_or(ParseError)?; + if size < 64 && rest.len() >= size as usize { + let (label, rest) = bytes.split_at(1 + size as usize); + // SAFETY: 'label' begins with a valid length octet. + Ok((unsafe { Label::from_bytes_unchecked(label) }, rest)) + } else { + Err(ParseError) + } + } +} + //--- Inspection impl Label { @@ -182,3 +200,54 @@ impl fmt::Debug for Label { .finish() } } + +//----------- LabelIter ------------------------------------------------------ + +/// An iterator over encoded [`Label`]s. +#[derive(Clone)] +pub struct LabelIter<'a> { + /// The buffer being read from. + /// + /// It is assumed to contain valid encoded labels. + bytes: &'a [u8], +} + +//--- Construction + +impl<'a> LabelIter<'a> { + /// Construct a new [`LabelIter`]. + /// + /// The byte string must contain a sequence of valid encoded labels. + pub const unsafe fn new_unchecked(bytes: &'a [u8]) -> Self { + Self { bytes } + } +} + +//--- Inspection + +impl<'a> LabelIter<'a> { + /// The remaining labels. + pub const fn remaining(&self) -> &'a [u8] { + self.bytes + } +} + +//--- Iteration + +impl<'a> Iterator for LabelIter<'a> { + type Item = &'a Label; + + fn next(&mut self) -> Option { + if self.bytes.is_empty() { + return None; + } + + // SAFETY: 'bytes' is assumed to only contain valid labels. + let (head, tail) = + unsafe { <&Label>::split_from(self.bytes).unwrap_unchecked() }; + self.bytes = tail; + Some(head) + } +} + +impl FusedIterator for LabelIter<'_> {} diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index 9ee96824a..9270f4d5c 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -15,7 +15,7 @@ //! is expressed can sometimes be confusing. mod label; -pub use label::Label; +pub use label::{Label, LabelIter}; -mod parsed; -pub use parsed::ParsedName; +mod reversed; +pub use reversed::{RevName, RevNameBuf}; diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs new file mode 100644 index 000000000..6283c5322 --- /dev/null +++ b/src/new_base/name/reversed.rs @@ -0,0 +1,214 @@ +//! Reversed DNS names. + +use core::{ + borrow::Borrow, + cmp::Ordering, + hash::{Hash, Hasher}, + ops::Deref, +}; + +use zerocopy_derive::*; + +use super::LabelIter; + +//----------- RevName -------------------------------------------------------- + +/// A domain name in reversed order. +/// +/// Domain names are conventionally presented and encoded from the innermost +/// label to the root label. This ordering is inconvenient and difficult to +/// use, making many common operations (e.g. comparing and ordering domain +/// names) more computationally expensive. A [`RevName`] stores the labels in +/// reversed order for more efficient use. +#[derive(Immutable, Unaligned)] +#[repr(transparent)] +pub struct RevName([u8]); + +//--- Constants + +impl RevName { + /// The maximum size of a (reversed) domain name. + /// + /// This is the same as the maximum size of a regular domain name. + pub const MAX_SIZE: usize = 255; + + /// The root name. + pub const ROOT: &'static Self = { + // SAFETY: A root label is the shortest valid name. + unsafe { Self::from_bytes_unchecked(&[0u8]) } + }; +} + +//--- Construction + +impl RevName { + /// Assume a byte string is a valid [`RevName`]. + /// + /// # Safety + /// + /// The byte string must begin with a root label (0-value byte). It must + /// be followed by any number of encoded labels, as long as the size of + /// the whole string is 255 bytes or less. + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'RevName' is 'repr(transparent)' to '[u8]', so casting a + // '[u8]' into a 'RevName' is sound. + core::mem::transmute(bytes) + } +} + +//--- Inspection + +impl RevName { + /// The size of this name in the wire format. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.len() == 1 + } + + /// A byte representation of the [`RevName`]. + /// + /// Note that labels appear in reverse order to the _conventional_ format + /// (it thus starts with the root label). + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } + + /// The labels in the [`RevName`]. + /// + /// Note that labels appear in reverse order to the _conventional_ format + /// (it thus starts with the root label). + pub const fn labels(&self) -> LabelIter<'_> { + // SAFETY: A 'RevName' always contains valid encoded labels. + unsafe { LabelIter::new_unchecked(self.as_bytes()) } + } +} + +//--- Equality + +impl PartialEq for RevName { + fn eq(&self, that: &Self) -> bool { + // Instead of iterating labels, blindly iterate bytes. The locations + // of labels don't matter since we're testing everything for equality. + + // NOTE: Label lengths (which are less than 64) aren't affected by + // 'to_ascii_lowercase', so this method can be applied uniformly. + let this = self.as_bytes().iter().map(u8::to_ascii_lowercase); + let that = that.as_bytes().iter().map(u8::to_ascii_lowercase); + + this.eq(that) + } +} + +impl Eq for RevName {} + +//--- Comparison + +impl PartialOrd for RevName { + fn partial_cmp(&self, that: &Self) -> Option { + Some(self.cmp(that)) + } +} + +impl Ord for RevName { + fn cmp(&self, that: &Self) -> Ordering { + // Unfortunately, names cannot be compared bytewise. Labels are + // preceded by their length octets, but a longer label can be less + // than a shorter one if its first bytes are less. We are forced to + // compare lexicographically over labels. + self.labels().cmp(that.labels()) + } +} + +//--- Hashing + +impl Hash for RevName { + fn hash(&self, state: &mut H) { + for byte in self.as_bytes() { + // NOTE: Label lengths (which are less than 64) aren't affected by + // 'to_ascii_lowercase', so this method can be applied uniformly. + state.write_u8(byte.to_ascii_lowercase()) + } + } +} + +//----------- RevNameBuf ----------------------------------------------------- + +/// A 256-byte buffer containing a [`RevName`]. +#[derive(Immutable, Unaligned)] +#[repr(C)] // make layout compatible with '[u8; 256]' +pub struct RevNameBuf { + /// The position of the root label in the buffer. + offset: u8, + + /// The buffer containing the [`RevName`]. + buffer: [u8; 255], +} + +//--- Construction + +impl RevNameBuf { + /// Copy a [`RevName`] into a buffer. + pub fn copy_from(name: &RevName) -> Self { + let offset = 255 - name.len() as u8; + let mut buffer = [0u8; 255]; + buffer[offset as usize..].copy_from_slice(name.as_bytes()); + Self { offset, buffer } + } +} + +//--- Access to the underlying 'RevName' + +impl Deref for RevNameBuf { + type Target = RevName; + + fn deref(&self) -> &Self::Target { + let name = &self.buffer[self.offset as usize..]; + // SAFETY: A 'RevNameBuf' always contains a valid 'RevName'. + unsafe { RevName::from_bytes_unchecked(name) } + } +} + +impl Borrow for RevNameBuf { + fn borrow(&self) -> &RevName { + self + } +} + +impl AsRef for RevNameBuf { + fn as_ref(&self) -> &RevName { + self + } +} + +//--- Forwarding equality, comparison, and hashing + +impl PartialEq for RevNameBuf { + fn eq(&self, that: &Self) -> bool { + **self == **that + } +} + +impl Eq for RevNameBuf {} + +impl PartialOrd for RevNameBuf { + fn partial_cmp(&self, that: &Self) -> Option { + Some(self.cmp(that)) + } +} + +impl Ord for RevNameBuf { + fn cmp(&self, that: &Self) -> Ordering { + (**self).cmp(&**that) + } +} + +impl Hash for RevNameBuf { + fn hash(&self, state: &mut H) { + (**self).hash(state) + } +} From 7ef218dc5e07c80abdce2e5837e9d4f868bb0969 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 00:13:40 +0100 Subject: [PATCH 012/111] [new_base/name/reversed] Implement complex parsing --- src/new_base/name/reversed.rs | 193 +++++++++++++++++++++++++++++++++- src/new_base/parse/mod.rs | 32 +++++- 2 files changed, 223 insertions(+), 2 deletions(-) diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 6283c5322..c884696e9 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -4,11 +4,19 @@ use core::{ borrow::Borrow, cmp::Ordering, hash::{Hash, Hasher}, - ops::Deref, + ops::{Deref, Range}, }; +use zerocopy::IntoBytes; use zerocopy_derive::*; +use crate::new_base::{ + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, +}; + use super::LabelIter; //----------- RevName -------------------------------------------------------- @@ -152,6 +160,14 @@ pub struct RevNameBuf { //--- Construction impl RevNameBuf { + /// Construct an empty, invalid buffer. + fn empty() -> Self { + Self { + offset: 0, + buffer: [0; 255], + } + } + /// Copy a [`RevName`] into a buffer. pub fn copy_from(name: &RevName) -> Self { let offset = 255 - name.len() as u8; @@ -161,6 +177,181 @@ impl RevNameBuf { } } +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for RevNameBuf { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + // NOTE: The input may be controlled by an attacker. Compression + // pointers can be arranged to cause loops or to access every byte in + // the message in random order. Instead of performing complex loop + // detection, which would probably perform allocations, we simply + // disallow a name to point to data _after_ it. Standard name + // compressors will never generate such pointers. + + let message = message.as_bytes(); + let mut buffer = Self::empty(); + + // Perform the first iteration early, to catch the end of the name. + let bytes = message.get(start..).ok_or(ParseError)?; + let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; + let orig_end = message.len() - rest.len(); + + // Traverse compression pointers. + while let Some(start) = pointer.map(usize::from) { + // Ensure the referenced position comes earlier. + if start >= start { + return Err(ParseError); + } + + // Keep going, from the referenced position. + let bytes = message.get(start..).ok_or(ParseError)?; + (pointer, _) = parse_segment(bytes, &mut buffer)?; + continue; + } + + // Stop and return the original end. + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been prepended into it). + Ok((buffer, orig_end)) + } +} + +impl<'a> ParseFromMessage<'a> for RevNameBuf { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + // See 'split_from_message()' for details. The only differences are + // in the range of the first iteration, and the check that the first + // iteration exactly covers the input range. + + let message = message.as_bytes(); + let mut buffer = Self::empty(); + + // Perform the first iteration early, to catch the end of the name. + let bytes = message.get(range.clone()).ok_or(ParseError)?; + let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; + + if !rest.is_empty() { + // The name didn't reach the end of the input range, fail. + return Err(ParseError); + } + + // Traverse compression pointers. + while let Some(start) = pointer.map(usize::from) { + // Ensure the referenced position comes earlier. + if start >= start { + return Err(ParseError); + } + + // Keep going, from the referenced position. + let bytes = message.get(start..).ok_or(ParseError)?; + (pointer, _) = parse_segment(bytes, &mut buffer)?; + continue; + } + + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been prepended into it). + Ok(buffer) + } +} + +/// Parse an encoded and potentially-compressed domain name, without +/// following any compression pointer. +fn parse_segment<'a>( + mut bytes: &'a [u8], + buffer: &mut RevNameBuf, +) -> Result<(Option, &'a [u8]), ParseError> { + loop { + let (&length, rest) = bytes.split_first().ok_or(ParseError)?; + if length == 0 { + // Found the root, stop. + buffer.prepend(&[0u8]); + return Ok((None, rest)); + } else if length < 64 { + // This looks like a regular label. + + if rest.len() < length as usize { + // The input doesn't contain the whole label. + return Err(ParseError); + } else if buffer.offset < 2 + length { + // The output name would exceed 254 bytes (this isn't + // the root label, so it can't fill the 255th byte). + return Err(ParseError); + } + + let (label, rest) = bytes.split_at(1 + length as usize); + buffer.prepend(label); + bytes = rest; + } else if length >= 0xC0 { + // This looks like a compression pointer. + + let (&extra, rest) = rest.split_first().ok_or(ParseError)?; + let pointer = u16::from_be_bytes([length, extra]); + + // NOTE: We don't verify the pointer here, that's left to + // the caller (since they have to actually use it). + return Ok((Some(pointer & 0x3FFF), rest)); + } else { + // This is an invalid or deprecated label type. + return Err(ParseError); + } + } +} + +//--- Parsing from general byte strings + +impl<'a> SplitFrom<'a> for RevNameBuf { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let mut buffer = Self::empty(); + + let (pointer, rest) = parse_segment(bytes, &mut buffer)?; + if pointer.is_some() { + // We can't follow compression pointers, so fail. + return Err(ParseError); + } + + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been prepended into it). + Ok((buffer, rest)) + } +} + +impl<'a> ParseFrom<'a> for RevNameBuf { + fn parse_from(bytes: &'a [u8]) -> Result { + let mut buffer = Self::empty(); + + let (pointer, rest) = parse_segment(bytes, &mut buffer)?; + if pointer.is_some() { + // We can't follow compression pointers, so fail. + return Err(ParseError); + } else if !rest.is_empty() { + // The name didn't reach the end of the input range, fail. + return Err(ParseError); + } + + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been prepended into it). + Ok(buffer) + } +} + +//--- Interaction + +impl RevNameBuf { + /// Prepend bytes to this buffer. + /// + /// This is an internal convenience function used while building buffers. + fn prepend(&mut self, label: &[u8]) { + self.offset -= label.len() as u8; + self.buffer[self.offset as usize..][..label.len()] + .copy_from_slice(label); + } +} + //--- Access to the underlying 'RevName' impl Deref for RevNameBuf { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index a273717be..fac5a8e9d 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,6 +1,6 @@ //! Parsing DNS messages from the wire format. -use core::fmt; +use core::{fmt, ops::Range}; use zerocopy::{FromBytes, Immutable, KnownLayout}; @@ -13,6 +13,36 @@ pub use question::{ParseQuestion, ParseQuestions, VisitQuestion}; mod record; pub use record::{ParseRecord, ParseRecords, VisitRecord}; +use super::Message; + +//----------- Message-aware parsing traits ----------------------------------- + +/// A type that can be parsed from a DNS message. +pub trait SplitFromMessage<'a>: Sized { + /// Parse a value of [`Self`] from the start of a byte string within a + /// particular DNS message. + /// + /// If parsing is successful, the parsed value and the rest of the string + /// are returned. Otherwise, a [`ParseError`] is returned. + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError>; +} + +/// A type that can be parsed from a string in a DNS message. +pub trait ParseFromMessage<'a>: Sized { + /// Parse a value of [`Self`] from a byte string within a particular DNS + /// message. + /// + /// If parsing is successful, the parsed value is returned. Otherwise, a + /// [`ParseError`] is returned. + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result; +} + //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. From 3c0b4cd951499507db1cbbacbf537c3b5247f8e4 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 00:23:31 +0100 Subject: [PATCH 013/111] [new_base/name] Add some 'Debug' impls --- src/new_base/name/label.rs | 16 ++++++++++++++++ src/new_base/name/reversed.rs | 26 ++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 48420df3a..597a5eb92 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -251,3 +251,19 @@ impl<'a> Iterator for LabelIter<'a> { } impl FusedIterator for LabelIter<'_> {} + +//--- Formatting + +impl fmt::Debug for LabelIter<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Labels<'a>(&'a LabelIter<'a>); + + impl fmt::Debug for Labels<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.0.clone()).finish() + } + } + + f.debug_tuple("LabelIter").field(&Labels(self)).finish() + } +} diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index c884696e9..1c315ba34 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -3,6 +3,7 @@ use core::{ borrow::Borrow, cmp::Ordering, + fmt, hash::{Hash, Hasher}, ops::{Deref, Range}, }; @@ -144,6 +145,31 @@ impl Hash for RevName { } } +//--- Formatting + +impl fmt::Debug for RevName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct RevLabels<'a>(&'a RevName); + + impl fmt::Debug for RevLabels<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut first = true; + self.0.labels().try_for_each(|label| { + if !first { + f.write_str(".")?; + } else { + first = false; + } + + label.fmt(f) + }) + } + } + + f.debug_tuple("RevName").field(&RevLabels(self)).finish() + } +} + //----------- RevNameBuf ----------------------------------------------------- /// A 256-byte buffer containing a [`RevName`]. From 4b8cc79fd22a5f6f1bb19016112829835d6d747b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 00:50:43 +0100 Subject: [PATCH 014/111] [new_base] Implement parsing for 'Question' and 'Record' --- src/new_base/name/reversed.rs | 2 +- src/new_base/parse/mod.rs | 36 +++++++- src/new_base/question.rs | 80 ++++++++++++++---- src/new_base/record/mod.rs | 150 +++++++++++++++++++++++++++++----- 4 files changed, 229 insertions(+), 39 deletions(-) diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 1c315ba34..6c58b7ada 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -328,7 +328,7 @@ fn parse_segment<'a>( } } -//--- Parsing from general byte strings +//--- Parsing from bytes impl<'a> SplitFrom<'a> for RevNameBuf { fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index fac5a8e9d..fba82d65c 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -2,7 +2,7 @@ use core::{fmt, ops::Range}; -use zerocopy::{FromBytes, Immutable, KnownLayout}; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; mod message; pub use message::{MessagePart, ParseMessage, VisitMessagePart}; @@ -43,6 +43,40 @@ pub trait ParseFromMessage<'a>: Sized { ) -> Result; } +//--- Carrying over 'zerocopy' traits + +// NOTE: We can't carry over 'read_from_prefix' because the trait impls would +// conflict. We kept 'ref_from_prefix' since it's more general. + +impl<'a, T: ?Sized> SplitFromMessage<'a> for &'a T +where + T: FromBytes + KnownLayout + Immutable, +{ + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let message = message.as_bytes(); + let bytes = message.get(start..).ok_or(ParseError)?; + let (this, rest) = T::ref_from_prefix(bytes)?; + Ok((this, message.len() - rest.len())) + } +} + +impl<'a, T: ?Sized> ParseFromMessage<'a> for &'a T +where + T: FromBytes + KnownLayout + Immutable, +{ + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let message = message.as_bytes(); + let bytes = message.get(range).ok_or(ParseError)?; + Ok(T::ref_from_bytes(bytes)?) + } +} + //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 16e388c1c..121eedff4 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -1,19 +1,25 @@ //! DNS questions. -use zerocopy::{network_endian::U16, FromBytes}; +use core::ops::Range; + +use zerocopy::network_endian::U16; use zerocopy_derive::*; use super::{ - name::ParsedName, - parse::{ParseError, ParseFrom, SplitFrom}, + name::RevNameBuf, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, }; //----------- Question ------------------------------------------------------- /// A DNS question. -pub struct Question<'a> { +#[derive(Clone)] +pub struct Question { /// The domain name being requested. - pub qname: &'a ParsedName, + pub qname: N, /// The type of the requested records. pub qtype: QType, @@ -22,11 +28,14 @@ pub struct Question<'a> { pub qclass: QClass, } +/// An unparsed DNS question. +pub type UnparsedQuestion = Question; + //--- Construction -impl<'a> Question<'a> { +impl Question { /// Construct a new [`Question`]. - pub fn new(qname: &'a ParsedName, qtype: QType, qclass: QClass) -> Self { + pub fn new(qname: N, qtype: QType, qclass: QClass) -> Self { Self { qname, qtype, @@ -35,22 +44,61 @@ impl<'a> Question<'a> { } } -//--- Parsing +//--- Parsing from DNS messages + +impl<'a, N> SplitFromMessage<'a> for Question +where + N: SplitFromMessage<'a>, +{ + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let (qname, rest) = N::split_from_message(message, start)?; + let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; + let (&qclass, rest) = <&QClass>::split_from_message(message, rest)?; + Ok((Self::new(qname, qtype, qclass), rest)) + } +} + +impl<'a, N> ParseFromMessage<'a> for Question +where + N: SplitFromMessage<'a>, +{ + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let (qname, rest) = N::split_from_message(message, range.start)?; + let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; + let &qclass = + <&QClass>::parse_from_message(message, rest..range.end)?; + Ok(Self::new(qname, qtype, qclass)) + } +} + +//--- Parsing from bytes -impl<'a> SplitFrom<'a> for Question<'a> { +impl<'a, N> SplitFrom<'a> for Question +where + N: SplitFrom<'a>, +{ fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (qname, rest) = <&ParsedName>::split_from(bytes)?; - let (qtype, rest) = QType::read_from_prefix(rest)?; - let (qclass, rest) = QClass::read_from_prefix(rest)?; + let (qname, rest) = N::split_from(bytes)?; + let (&qtype, rest) = <&QType>::split_from(rest)?; + let (&qclass, rest) = <&QClass>::split_from(rest)?; Ok((Self::new(qname, qtype, qclass), rest)) } } -impl<'a> ParseFrom<'a> for Question<'a> { +impl<'a, N> ParseFrom<'a> for Question +where + N: SplitFrom<'a>, +{ fn parse_from(bytes: &'a [u8]) -> Result { - let (qname, rest) = <&ParsedName>::split_from(bytes)?; - let (qtype, rest) = QType::read_from_prefix(rest)?; - let qclass = QClass::read_from_bytes(rest)?; + let (qname, rest) = N::split_from(bytes)?; + let (&qtype, rest) = <&QType>::split_from(rest)?; + let &qclass = <&QClass>::parse_from(rest)?; Ok(Self::new(qname, qtype, qclass)) } } diff --git a/src/new_base/record/mod.rs b/src/new_base/record/mod.rs index fc348b710..42336dc6f 100644 --- a/src/new_base/record/mod.rs +++ b/src/new_base/record/mod.rs @@ -1,22 +1,31 @@ //! DNS records. +use core::{ + borrow::Borrow, + ops::{Deref, Range}, +}; + use zerocopy::{ network_endian::{U16, U32}, - FromBytes, + FromBytes, IntoBytes, }; use zerocopy_derive::*; use super::{ - name::ParsedName, - parse::{ParseError, ParseFrom, SplitFrom}, + name::RevNameBuf, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, }; //----------- Record --------------------------------------------------------- -/// An unparsed DNS record. -pub struct Record<'a> { +/// A DNS record. +#[derive(Clone)] +pub struct Record { /// The name of the record. - pub rname: &'a ParsedName, + pub rname: N, /// The type of the record. pub rtype: RType, @@ -28,19 +37,22 @@ pub struct Record<'a> { pub ttl: TTL, /// Unparsed record data. - pub rdata: &'a [u8], + pub rdata: D, } +/// An unparsed DNS record. +pub type UnparsedRecord<'a> = Record; + //--- Construction -impl<'a> Record<'a> { +impl Record { /// Construct a new [`Record`]. pub fn new( - rname: &'a ParsedName, + rname: N, rtype: RType, rclass: RClass, ttl: TTL, - rdata: &'a [u8], + rdata: D, ) -> Self { Self { rname, @@ -52,31 +64,35 @@ impl<'a> Record<'a> { } } -//--- Parsing +//--- Parsing from bytes -impl<'a> SplitFrom<'a> for Record<'a> { +impl<'a, N, D> SplitFrom<'a> for Record +where + N: SplitFrom<'a>, + D: SplitFrom<'a>, +{ fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (rname, rest) = <&ParsedName>::split_from(bytes)?; + let (rname, rest) = N::split_from(bytes)?; let (rtype, rest) = RType::read_from_prefix(rest)?; let (rclass, rest) = RClass::read_from_prefix(rest)?; let (ttl, rest) = TTL::read_from_prefix(rest)?; - let (size, rest) = U16::read_from_prefix(rest)?; - let size = size.get() as usize; - let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + let (rdata, rest) = D::split_from(rest)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } } -impl<'a> ParseFrom<'a> for Record<'a> { +impl<'a, N, D> ParseFrom<'a> for Record +where + N: SplitFrom<'a>, + D: ParseFrom<'a>, +{ fn parse_from(bytes: &'a [u8]) -> Result { - let (rname, rest) = <&ParsedName>::split_from(bytes)?; + let (rname, rest) = N::split_from(bytes)?; let (rtype, rest) = RType::read_from_prefix(rest)?; let (rclass, rest) = RClass::read_from_prefix(rest)?; let (ttl, rest) = TTL::read_from_prefix(rest)?; - let (size, rest) = U16::read_from_prefix(rest)?; - let size = size.get() as usize; - let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; + let rdata = D::parse_from(rest)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -153,3 +169,95 @@ pub struct TTL { /// The underlying value. pub value: U32, } + +//----------- UnparsedRecordData --------------------------------------------- + +/// Unparsed DNS record data. +#[derive(Immutable, Unaligned)] +#[repr(transparent)] +pub struct UnparsedRecordData([u8]); + +//--- Construction + +impl UnparsedRecordData { + /// Assume a byte string is a valid [`UnparsedRecordData`]. + /// + /// # Safety + /// + /// The byte string must be 65,535 bytes or shorter. + pub const unsafe fn new_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'UnparsedRecordData' is 'repr(transparent)' to '[u8]', so + // casting a '[u8]' into an 'UnparsedRecordData' is sound. + core::mem::transmute(bytes) + } +} + +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for &'a UnparsedRecordData { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let message = message.as_bytes(); + let bytes = message.get(start..).ok_or(ParseError)?; + let (this, rest) = Self::split_from(bytes)?; + Ok((this, message.len() - rest.len())) + } +} + +impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let message = message.as_bytes(); + let bytes = message.get(range).ok_or(ParseError)?; + Self::parse_from(bytes) + } +} + +//--- Parsing from bytes + +impl<'a> SplitFrom<'a> for &'a UnparsedRecordData { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (size, rest) = U16::read_from_prefix(bytes)?; + let size = size.get() as usize; + let (data, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + // SAFETY: 'data.len() == size' which is a 'u16'. + let this = unsafe { UnparsedRecordData::new_unchecked(data) }; + Ok((this, rest)) + } +} + +impl<'a> ParseFrom<'a> for &'a UnparsedRecordData { + fn parse_from(bytes: &'a [u8]) -> Result { + let (size, rest) = U16::read_from_prefix(bytes)?; + let size = size.get() as usize; + let data = <[u8]>::ref_from_bytes_with_elems(rest, size)?; + // SAFETY: 'data.len() == size' which is a 'u16'. + Ok(unsafe { UnparsedRecordData::new_unchecked(data) }) + } +} + +//--- Access to the underlying bytes + +impl Deref for UnparsedRecordData { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Borrow<[u8]> for UnparsedRecordData { + fn borrow(&self) -> &[u8] { + self + } +} + +impl AsRef<[u8]> for UnparsedRecordData { + fn as_ref(&self) -> &[u8] { + self + } +} From b56260c95d4e58ab37eedf28410dd4416b6e5349 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 00:57:39 +0100 Subject: [PATCH 015/111] [new_base/parse] Update with new question/record types --- src/new_base/mod.rs | 4 ++-- src/new_base/name/reversed.rs | 2 +- src/new_base/parse/message.rs | 10 ++++----- src/new_base/parse/question.rs | 38 +++++++++++++++++----------------- src/new_base/parse/record.rs | 18 ++++++++-------- 5 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index f7baced2b..a444989e4 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -10,9 +10,9 @@ pub use message::{Header, HeaderFlags, Message, SectionCounts}; pub mod name; mod question; -pub use question::{QClass, QType, Question}; +pub use question::{QClass, QType, Question, UnparsedQuestion}; pub mod record; -pub use record::Record; +pub use record::{Record, UnparsedRecord}; pub mod parse; diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 6c58b7ada..864a5e1bc 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -173,7 +173,7 @@ impl fmt::Debug for RevName { //----------- RevNameBuf ----------------------------------------------------- /// A 256-byte buffer containing a [`RevName`]. -#[derive(Immutable, Unaligned)] +#[derive(Clone, Immutable, Unaligned)] #[repr(C)] // make layout compatible with '[u8; 256]' pub struct RevNameBuf { /// The position of the root label in the buffer. diff --git a/src/new_base/parse/message.rs b/src/new_base/parse/message.rs index eaea9845d..1c964588a 100644 --- a/src/new_base/parse/message.rs +++ b/src/new_base/parse/message.rs @@ -2,7 +2,7 @@ use core::ops::ControlFlow; -use crate::new_base::{Header, Question, Record}; +use crate::new_base::{Header, UnparsedQuestion, UnparsedRecord}; /// A type that can be constructed by parsing a DNS message. pub trait ParseMessage<'a>: Sized { @@ -36,14 +36,14 @@ pub trait VisitMessagePart<'a> { /// A component of a DNS message. pub enum MessagePart<'a> { /// A question. - Question(Question<'a>), + Question(&'a UnparsedQuestion), /// An answer record. - Answer(Record<'a>), + Answer(&'a UnparsedRecord<'a>), /// An authority record. - Authority(Record<'a>), + Authority(&'a UnparsedRecord<'a>), /// An additional record. - Additional(Record<'a>), + Additional(&'a UnparsedRecord<'a>), } diff --git a/src/new_base/parse/question.rs b/src/new_base/parse/question.rs index e08ea6283..784cadc09 100644 --- a/src/new_base/parse/question.rs +++ b/src/new_base/parse/question.rs @@ -7,26 +7,26 @@ use std::boxed::Box; #[cfg(feature = "std")] use std::vec::Vec; -use crate::new_base::Question; +use crate::new_base::UnparsedQuestion; //----------- Trait definitions ---------------------------------------------- /// A type that can be constructed by parsing exactly one DNS question. -pub trait ParseQuestion<'a>: Sized { +pub trait ParseQuestion: Sized { /// The type of parse errors. // TODO: Remove entirely? type Error; /// Parse the given DNS question. fn parse_question( - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error>; } /// A type that can be constructed by parsing zero or more DNS questions. -pub trait ParseQuestions<'a>: Sized { +pub trait ParseQuestions: Sized { /// The type of visitors for incrementally building the output. - type Visitor: Default + VisitQuestion<'a>; + type Visitor: Default + VisitQuestion; /// The type of errors from converting a visitor into [`Self`]. // TODO: Just use 'Visitor::Error'? Or remove entirely? @@ -37,36 +37,36 @@ pub trait ParseQuestions<'a>: Sized { } /// A type that can visit DNS questions. -pub trait VisitQuestion<'a> { +pub trait VisitQuestion { /// The type of errors produced by visits. type Error; /// Visit a question. fn visit_question( &mut self, - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error>; } //----------- Trait implementations ------------------------------------------ -impl<'a> ParseQuestion<'a> for Question<'a> { +impl ParseQuestion for UnparsedQuestion { type Error = Infallible; fn parse_question( - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error> { - Ok(ControlFlow::Break(question)) + Ok(ControlFlow::Break(question.clone())) } } //--- Impls for 'Option' -impl<'a, T: ParseQuestion<'a>> ParseQuestion<'a> for Option { +impl ParseQuestion for Option { type Error = T::Error; fn parse_question( - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error> { Ok(match T::parse_question(question)? { ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), @@ -75,7 +75,7 @@ impl<'a, T: ParseQuestion<'a>> ParseQuestion<'a> for Option { } } -impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Option { +impl ParseQuestions for Option { type Visitor = Option; type Error = Infallible; @@ -84,12 +84,12 @@ impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Option { } } -impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Option { +impl VisitQuestion for Option { type Error = T::Error; fn visit_question( &mut self, - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error> { if self.is_some() { return Ok(ControlFlow::Continue(())); @@ -108,7 +108,7 @@ impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Option { //--- Impls for 'Vec' #[cfg(feature = "std")] -impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Vec { +impl ParseQuestions for Vec { type Visitor = Vec; type Error = Infallible; @@ -118,12 +118,12 @@ impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Vec { } #[cfg(feature = "std")] -impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Vec { +impl VisitQuestion for Vec { type Error = T::Error; fn visit_question( &mut self, - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error> { Ok(match T::parse_question(question)? { ControlFlow::Break(elem) => { @@ -138,7 +138,7 @@ impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Vec { //--- Impls for 'Box<[T]>' #[cfg(feature = "std")] -impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Box<[T]> { +impl ParseQuestions for Box<[T]> { type Visitor = Vec; type Error = Infallible; diff --git a/src/new_base/parse/record.rs b/src/new_base/parse/record.rs index c93f2f8d1..75e98a36a 100644 --- a/src/new_base/parse/record.rs +++ b/src/new_base/parse/record.rs @@ -7,7 +7,7 @@ use std::boxed::Box; #[cfg(feature = "std")] use std::vec::Vec; -use crate::new_base::Record; +use crate::new_base::UnparsedRecord; //----------- Trait definitions ---------------------------------------------- @@ -19,7 +19,7 @@ pub trait ParseRecord<'a>: Sized { /// Parse the given DNS record. fn parse_record( - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error>; } @@ -44,19 +44,19 @@ pub trait VisitRecord<'a> { /// Visit a record. fn visit_record( &mut self, - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error>; } //----------- Trait implementations ------------------------------------------ -impl<'a> ParseRecord<'a> for Record<'a> { +impl<'a> ParseRecord<'a> for UnparsedRecord<'a> { type Error = Infallible; fn parse_record( - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error> { - Ok(ControlFlow::Break(record)) + Ok(ControlFlow::Break(record.clone())) } } @@ -66,7 +66,7 @@ impl<'a, T: ParseRecord<'a>> ParseRecord<'a> for Option { type Error = T::Error; fn parse_record( - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error> { Ok(match T::parse_record(record)? { ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), @@ -89,7 +89,7 @@ impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Option { fn visit_record( &mut self, - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error> { if self.is_some() { return Ok(ControlFlow::Continue(())); @@ -123,7 +123,7 @@ impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Vec { fn visit_record( &mut self, - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error> { Ok(match T::parse_record(record)? { ControlFlow::Break(elem) => { From 1622e586b5c5273bb794f48abea89d490249f5ca Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 01:00:48 +0100 Subject: [PATCH 016/111] [new_base/name] Fix bugs (thanks clippy) --- src/new_base/name/label.rs | 2 ++ src/new_base/name/reversed.rs | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 597a5eb92..b93b32f80 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -217,6 +217,8 @@ pub struct LabelIter<'a> { impl<'a> LabelIter<'a> { /// Construct a new [`LabelIter`]. /// + /// # Safety + /// /// The byte string must contain a sequence of valid encoded labels. pub const unsafe fn new_unchecked(bytes: &'a [u8]) -> Self { Self { bytes } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 864a5e1bc..60a640f91 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -226,15 +226,17 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { let orig_end = message.len() - rest.len(); // Traverse compression pointers. + let mut old_start = start; while let Some(start) = pointer.map(usize::from) { // Ensure the referenced position comes earlier. - if start >= start { + if start >= old_start { return Err(ParseError); } // Keep going, from the referenced position. let bytes = message.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; + old_start = start; continue; } @@ -267,15 +269,17 @@ impl<'a> ParseFromMessage<'a> for RevNameBuf { } // Traverse compression pointers. + let mut old_start = range.start; while let Some(start) = pointer.map(usize::from) { // Ensure the referenced position comes earlier. - if start >= start { + if start >= old_start { return Err(ParseError); } // Keep going, from the referenced position. let bytes = message.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; + old_start = start; continue; } From f27f922eba6ce469c4171dab08129c3413c46f4c Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 01:04:03 +0100 Subject: [PATCH 017/111] [new_base] Make 'record' a private module --- src/new_base/mod.rs | 6 ++++-- src/new_base/{record/mod.rs => record.rs} | 0 2 files changed, 4 insertions(+), 2 deletions(-) rename src/new_base/{record/mod.rs => record.rs} (100%) diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index a444989e4..4307896fb 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -12,7 +12,9 @@ pub mod name; mod question; pub use question::{QClass, QType, Question, UnparsedQuestion}; -pub mod record; -pub use record::{Record, UnparsedRecord}; +mod record; +pub use record::{ + RClass, RType, Record, UnparsedRecord, UnparsedRecordData, TTL, +}; pub mod parse; diff --git a/src/new_base/record/mod.rs b/src/new_base/record.rs similarity index 100% rename from src/new_base/record/mod.rs rename to src/new_base/record.rs From dd123c12055c5b7dfbca2ed94cf53f2f4aac1da1 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 16 Dec 2024 13:05:28 +0100 Subject: [PATCH 018/111] Use 'zerocopy' 0.8.5 or newer It implements 'Hash' for the provided integer types. --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0072d61fa..9d078a5e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,8 +51,8 @@ tracing-subscriber = { version = "0.3.18", optional = true, features = ["env-fil # 'zerocopy' provides simple derives for converting types to and from byte # representations, along with network-endian integer primitives. These are # used to define simple elements of DNS messages and their serialization. -zerocopy = "0.8" -zerocopy-derive = "0.8" +zerocopy = "0.8.5" +zerocopy-derive = "0.8.5" [features] default = ["std", "rand"] From 46b2e45873879fc6b3b0019c0ec8b32f8b032c52 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 16 Dec 2024 14:33:26 +0100 Subject: [PATCH 019/111] Add module 'new_rdata' with most RFC 1035 types --- src/lib.rs | 1 + src/new_rdata/mod.rs | 3 + src/new_rdata/rfc1035.rs | 319 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 323 insertions(+) create mode 100644 src/new_rdata/mod.rs create mode 100644 src/new_rdata/rfc1035.rs diff --git a/src/lib.rs b/src/lib.rs index e9aef12b8..b2f7ac66c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -194,6 +194,7 @@ pub mod base; pub mod dep; pub mod net; pub mod new_base; +pub mod new_rdata; pub mod rdata; pub mod resolv; pub mod sign; diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs new file mode 100644 index 000000000..54afb39ee --- /dev/null +++ b/src/new_rdata/mod.rs @@ -0,0 +1,3 @@ +//! Record data types. + +pub mod rfc1035; diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs new file mode 100644 index 000000000..b8d893dff --- /dev/null +++ b/src/new_rdata/rfc1035.rs @@ -0,0 +1,319 @@ +//! Core record data types. + +use core::{fmt, net::Ipv4Addr, ops::Range, str::FromStr}; + +use zerocopy::network_endian::{U16, U32}; +use zerocopy_derive::*; + +use crate::new_base::{ + parse::{ParseError, ParseFromMessage, SplitFromMessage}, + Message, +}; + +//----------- A -------------------------------------------------------------- + +/// The IPv4 address of a host responsible for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct A { + /// The IPv4 address octets. + pub octets: [u8; 4], +} + +//--- Converting to and from 'Ipv4Addr' + +impl From for A { + fn from(value: Ipv4Addr) -> Self { + Self { + octets: value.octets(), + } + } +} + +impl From for Ipv4Addr { + fn from(value: A) -> Self { + Self::from(value.octets) + } +} + +//--- Parsing from a string + +impl FromStr for A { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + Ipv4Addr::from_str(s).map(A::from) + } +} + +//--- Formatting + +impl fmt::Display for A { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Ipv4Addr::from(*self).fmt(f) + } +} + +//----------- Ns ------------------------------------------------------------- + +/// The authoritative name server for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Ns { + /// The name of the authoritative server. + pub name: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + N::parse_from_message(message, range).map(|name| Self { name }) + } +} + +//----------- Cname ---------------------------------------------------------- + +/// The canonical name for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Cname { + /// The canonical name. + pub name: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + N::parse_from_message(message, range).map(|name| Self { name }) + } +} + +//----------- Soa ------------------------------------------------------------ + +/// The start of a zone of authority. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct Soa { + /// The name server which provided this zone. + pub mname: N, + + /// The mailbox of the maintainer of this zone. + pub rname: N, + + /// The version number of the original copy of this zone. + // TODO: Define a dedicated serial number type. + pub serial: U32, + + /// The number of seconds to wait until refreshing the zone. + pub refresh: U32, + + /// The number of seconds to wait until retrying a failed refresh. + pub retry: U32, + + /// The number of seconds until the zone is considered expired. + pub expire: U32, + + /// The minimum TTL for any record in this zone. + pub minimum: U32, +} + +//--- Parsing from DNS messages + +impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let (mname, rest) = N::split_from_message(message, range.start)?; + let (rname, rest) = N::split_from_message(message, rest)?; + let (&serial, rest) = <&U32>::split_from_message(message, rest)?; + let (&refresh, rest) = <&U32>::split_from_message(message, rest)?; + let (&retry, rest) = <&U32>::split_from_message(message, rest)?; + let (&expire, rest) = <&U32>::split_from_message(message, rest)?; + let &minimum = <&U32>::parse_from_message(message, rest..range.end)?; + + Ok(Self { + mname, + rname, + serial, + refresh, + retry, + expire, + minimum, + }) + } +} + +//----------- Wks ------------------------------------------------------------ + +/// Well-known services supported on this domain. +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct Wks { + /// The address of the host providing these services. + pub address: A, + + /// The IP protocol number for the services (e.g. TCP). + pub protocol: u8, + + /// A bitset of supported well-known ports. + pub ports: [u8], +} + +//--- Formatting + +impl fmt::Debug for Wks { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Ports<'a>(&'a [u8]); + + impl fmt::Debug for Ports<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let entries = self + .0 + .iter() + .enumerate() + .flat_map(|(i, &b)| (0..8).map(move |j| (i, j, b))) + .filter(|(_, j, b)| b & (1 << j) != 0) + .map(|(i, j, _)| i * 8 + j); + + f.debug_set().entries(entries).finish() + } + } + + f.debug_struct("Wks") + .field("address", &Ipv4Addr::from(self.address)) + .field("protocol", &self.protocol) + .field("ports", &Ports(&self.ports)) + .finish() + } +} + +//----------- Ptr ------------------------------------------------------------ + +/// A pointer to another domain name. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Ptr { + /// The referenced domain name. + pub name: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + N::parse_from_message(message, range).map(|name| Self { name }) + } +} + +// TODO: MINFO, HINFO, and TXT records, which need 'CharStr'. + +//----------- Mx ------------------------------------------------------------- + +/// A host that can exchange mail for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(C)] +pub struct Mx { + /// The preference for this host over others. + pub preference: U16, + + /// The domain name of the mail exchanger. + pub exchange: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let (&preference, rest) = + <&U16>::split_from_message(message, range.start)?; + let exchange = N::parse_from_message(message, rest..range.end)?; + Ok(Self { + preference, + exchange, + }) + } +} From 9e238ae43263172fc60ddf59229e2b436db020cd Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 16 Dec 2024 14:43:57 +0100 Subject: [PATCH 020/111] [new_base] Define 'CharStr' --- src/new_base/charstr.rs | 78 +++++++++++++++++++++++++++++++++++++++++ src/new_base/mod.rs | 3 ++ 2 files changed, 81 insertions(+) create mode 100644 src/new_base/charstr.rs diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs new file mode 100644 index 000000000..62d289718 --- /dev/null +++ b/src/new_base/charstr.rs @@ -0,0 +1,78 @@ +//! DNS "character strings". + +use core::ops::Range; + +use zerocopy::IntoBytes; +use zerocopy_derive::*; + +use super::{ + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, +}; + +//----------- CharStr -------------------------------------------------------- + +/// A DNS "character string". +#[derive(Immutable, Unaligned)] +#[repr(transparent)] +pub struct CharStr { + /// The underlying octets. + pub octets: [u8], +} + +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for &'a CharStr { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let bytes = &message.as_bytes()[start..]; + let (this, rest) = Self::split_from(bytes)?; + Ok((this, bytes.len() - rest.len())) + } +} + +impl<'a> ParseFromMessage<'a> for &'a CharStr { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) + } +} + +//--- Parsing from bytes + +impl<'a> SplitFrom<'a> for &'a CharStr { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (&length, rest) = bytes.split_first().ok_or(ParseError)?; + if length as usize > rest.len() { + return Err(ParseError); + } + let (bytes, rest) = rest.split_at(length as usize); + + // SAFETY: 'CharStr' is 'repr(transparent)' to '[u8]'. + Ok((unsafe { core::mem::transmute::<&[u8], Self>(bytes) }, rest)) + } +} + +impl<'a> ParseFrom<'a> for &'a CharStr { + fn parse_from(bytes: &'a [u8]) -> Result { + let (&length, rest) = bytes.split_first().ok_or(ParseError)?; + if length as usize != rest.len() { + return Err(ParseError); + } + + // SAFETY: 'CharStr' is 'repr(transparent)' to '[u8]'. + Ok(unsafe { core::mem::transmute::<&[u8], Self>(rest) }) + } +} + +// TODO: Formatting diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 4307896fb..428584b68 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -9,6 +9,9 @@ pub use message::{Header, HeaderFlags, Message, SectionCounts}; pub mod name; +mod charstr; +pub use charstr::CharStr; + mod question; pub use question::{QClass, QType, Question, UnparsedQuestion}; From bb63a1ee8ba95ef99f95721030ca4931b82fe860 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 16 Dec 2024 14:55:55 +0100 Subject: [PATCH 021/111] [new_rdata] Define 'Hinfo' --- src/new_rdata/rfc1035.rs | 49 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index b8d893dff..52bfd02df 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -2,12 +2,17 @@ use core::{fmt, net::Ipv4Addr, ops::Range, str::FromStr}; -use zerocopy::network_endian::{U16, U32}; +use zerocopy::{ + network_endian::{U16, U32}, + IntoBytes, +}; use zerocopy_derive::*; use crate::new_base::{ - parse::{ParseError, ParseFromMessage, SplitFromMessage}, - Message, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + CharStr, Message, }; //----------- A -------------------------------------------------------------- @@ -272,7 +277,41 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { } } -// TODO: MINFO, HINFO, and TXT records, which need 'CharStr'. +//----------- Hinfo ---------------------------------------------------------- + +/// Information about the host computer. +pub struct Hinfo<'a> { + /// The CPU type. + pub cpu: &'a CharStr, + + /// The OS type. + pub os: &'a CharStr, +} + +//--- Parsing from DNS messages + +impl<'a> ParseFromMessage<'a> for Hinfo<'a> { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) + } +} + +//--- Parsing from bytes + +impl<'a> ParseFrom<'a> for Hinfo<'a> { + fn parse_from(bytes: &'a [u8]) -> Result { + let (cpu, rest) = <&CharStr>::split_from(bytes)?; + let os = <&CharStr>::parse_from(rest)?; + Ok(Self { cpu, os }) + } +} //----------- Mx ------------------------------------------------------------- @@ -317,3 +356,5 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { }) } } + +// TODO: TXT records. From 5ec06dca109fbb81106ccea92358582f543addd8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 16:41:15 +0100 Subject: [PATCH 022/111] [new_rdata/rfc1035] Implement (basic) 'Txt' records --- src/new_rdata/rfc1035.rs | 42 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 52bfd02df..80db47141 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -357,4 +357,44 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { } } -// TODO: TXT records. +//----------- Txt ------------------------------------------------------------ + +/// Free-form text strings about this domain. +#[derive(IntoBytes, Immutable, Unaligned)] +#[repr(transparent)] +pub struct Txt { + /// The text strings, as concatenated [`CharStr`]s. + content: [u8], +} + +// TODO: Support for iterating over the contained 'CharStr's. + +//--- Parsing from DNS messages + +impl<'a> ParseFromMessage<'a> for &'a Txt { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) + } +} + +//--- Parsing from bytes + +impl<'a> ParseFrom<'a> for &'a Txt { + fn parse_from(bytes: &'a [u8]) -> Result { + // NOTE: The input must contain at least one 'CharStr'. + let (_, mut rest) = <&CharStr>::split_from(bytes)?; + while !rest.is_empty() { + (_, rest) = <&CharStr>::split_from(rest)?; + } + + // SAFETY: 'Txt' is 'repr(transparent)' to '[u8]'. + Ok(unsafe { core::mem::transmute::<&'a [u8], Self>(bytes) }) + } +} From 6ad095eed19a5b82ea357a3c98edd108e6bb7c6b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 16:48:49 +0100 Subject: [PATCH 023/111] [new_rdata/rfc1035] Add 'ParseFrom' impls where missing --- src/new_rdata/rfc1035.rs | 61 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 80db47141..4d3a07c47 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -108,6 +108,14 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { } } +//--- Parsing from bytes + +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { + fn parse_from(bytes: &'a [u8]) -> Result { + N::parse_from(bytes).map(|name| Self { name }) + } +} + //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. @@ -143,6 +151,14 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { } } +//--- Parsing from bytes + +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { + fn parse_from(bytes: &'a [u8]) -> Result { + N::parse_from(bytes).map(|name| Self { name }) + } +} + //----------- Soa ------------------------------------------------------------ /// The start of a zone of authority. @@ -198,6 +214,30 @@ impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { } } +//--- Parsing from bytes + +impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { + fn parse_from(bytes: &'a [u8]) -> Result { + let (mname, rest) = N::split_from(bytes)?; + let (rname, rest) = N::split_from(rest)?; + let (&serial, rest) = <&U32>::split_from(rest)?; + let (&refresh, rest) = <&U32>::split_from(rest)?; + let (&retry, rest) = <&U32>::split_from(rest)?; + let (&expire, rest) = <&U32>::split_from(rest)?; + let &minimum = <&U32>::parse_from(rest)?; + + Ok(Self { + mname, + rname, + serial, + refresh, + retry, + expire, + minimum, + }) + } +} + //----------- Wks ------------------------------------------------------------ /// Well-known services supported on this domain. @@ -277,6 +317,14 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { } } +//--- Parsing from bytes + +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { + fn parse_from(bytes: &'a [u8]) -> Result { + N::parse_from(bytes).map(|name| Self { name }) + } +} + //----------- Hinfo ---------------------------------------------------------- /// Information about the host computer. @@ -357,6 +405,19 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { } } +//--- Parsing from bytes + +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Mx { + fn parse_from(bytes: &'a [u8]) -> Result { + let (&preference, rest) = <&U16>::split_from(bytes)?; + let exchange = N::parse_from(rest)?; + Ok(Self { + preference, + exchange, + }) + } +} + //----------- Txt ------------------------------------------------------------ /// Free-form text strings about this domain. From 90e15bb7dae03605077eaa579eec1b6a3a6ee202 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 17:17:07 +0100 Subject: [PATCH 024/111] [new_rdata/rfc1035] Don't use 'zerocopy' around names --- src/new_rdata/rfc1035.rs | 64 +++------------------------------------- 1 file changed, 4 insertions(+), 60 deletions(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 4d3a07c47..9415f5362 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -76,21 +76,7 @@ impl fmt::Display for A { //----------- Ns ------------------------------------------------------------- /// The authoritative name server for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, -)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct Ns { /// The name of the authoritative server. @@ -119,21 +105,7 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, -)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct Cname { /// The canonical name. @@ -285,21 +257,7 @@ impl fmt::Debug for Wks { //----------- Ptr ------------------------------------------------------------ /// A pointer to another domain name. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, -)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct Ptr { /// The referenced domain name. @@ -364,21 +322,7 @@ impl<'a> ParseFrom<'a> for Hinfo<'a> { //----------- Mx ------------------------------------------------------------- /// A host that can exchange mail for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, -)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(C)] pub struct Mx { /// The preference for this host over others. From c86a57e8373d487e5afc3bd57c2fa92d618ae550 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 17:18:01 +0100 Subject: [PATCH 025/111] [new_base/charstr] Impl 'Eq' and 'Debug' --- src/new_base/charstr.rs | 52 ++++++++++++++++++++++++++++++++++++++-- src/new_rdata/rfc1035.rs | 1 + 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 62d289718..dfe2bc8f9 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -1,6 +1,6 @@ //! DNS "character strings". -use core::ops::Range; +use core::{fmt, ops::Range}; use zerocopy::IntoBytes; use zerocopy_derive::*; @@ -75,4 +75,52 @@ impl<'a> ParseFrom<'a> for &'a CharStr { } } -// TODO: Formatting +//--- Equality + +impl PartialEq for CharStr { + fn eq(&self, other: &Self) -> bool { + self.octets.eq_ignore_ascii_case(&other.octets) + } +} + +impl Eq for CharStr {} + +//--- Formatting + +impl fmt::Debug for CharStr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use fmt::Write; + + struct Native<'a>(&'a [u8]); + impl fmt::Debug for Native<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("b\"")?; + for &b in self.0 { + f.write_str(match b { + b'"' => "\\\"", + b' ' => " ", + b'\n' => "\\n", + b'\r' => "\\r", + b'\t' => "\\t", + b'\\' => "\\\\", + + _ => { + if b.is_ascii_graphic() { + f.write_char(b as char)?; + } else { + write!(f, "\\x{:02X}", b)?; + } + continue; + } + })?; + } + f.write_char('"')?; + Ok(()) + } + } + + f.debug_struct("CharStr") + .field("content", &Native(&self.octets)) + .finish() + } +} diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 9415f5362..e54c348be 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -286,6 +286,7 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { //----------- Hinfo ---------------------------------------------------------- /// Information about the host computer. +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Hinfo<'a> { /// The CPU type. pub cpu: &'a CharStr, From a607c7748790d6ee1850f161feeb58e57843faaf Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 17:36:11 +0100 Subject: [PATCH 026/111] [new_base] Add module 'serial' --- src/new_base/mod.rs | 3 ++ src/new_base/serial.rs | 87 ++++++++++++++++++++++++++++++++++++++++ src/new_rdata/rfc1035.rs | 9 ++--- 3 files changed, 94 insertions(+), 5 deletions(-) create mode 100644 src/new_base/serial.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 428584b68..3becedb29 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -21,3 +21,6 @@ pub use record::{ }; pub mod parse; + +mod serial; +pub use serial::Serial; diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs new file mode 100644 index 000000000..fe00923c3 --- /dev/null +++ b/src/new_base/serial.rs @@ -0,0 +1,87 @@ +//! Serial number arithmetic. +//! +//! See [RFC 1982](https://datatracker.ietf.org/doc/html/rfc1982). + +use core::{ + cmp::Ordering, + fmt, + ops::{Add, AddAssign}, +}; + +use zerocopy::network_endian::U32; +use zerocopy_derive::*; + +//----------- Serial --------------------------------------------------------- + +/// A serial number. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Serial(U32); + +//--- Addition + +impl Add for Serial { + type Output = Self; + + fn add(self, rhs: i32) -> Self::Output { + self.0.get().wrapping_add_signed(rhs).into() + } +} + +impl AddAssign for Serial { + fn add_assign(&mut self, rhs: i32) { + self.0 = self.0.get().wrapping_add_signed(rhs).into(); + } +} + +//--- Ordering + +impl PartialOrd for Serial { + fn partial_cmp(&self, other: &Self) -> Option { + let (lhs, rhs) = (self.0.get(), other.0.get()); + + if lhs == rhs { + Some(Ordering::Equal) + } else if lhs.abs_diff(rhs) == 1 << 31 { + None + } else if (lhs < rhs) ^ (lhs.abs_diff(rhs) > (1 << 31)) { + Some(Ordering::Less) + } else { + Some(Ordering::Greater) + } + } +} + +//--- Conversion to and from native integer types + +impl From for Serial { + fn from(value: u32) -> Self { + Self(U32::new(value)) + } +} + +impl From for u32 { + fn from(value: Serial) -> Self { + value.0.get() + } +} + +//--- Formatting + +impl fmt::Display for Serial { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.get().fmt(f) + } +} diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index e54c348be..4e09c9727 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -12,7 +12,7 @@ use crate::new_base::{ parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, }, - CharStr, Message, + CharStr, Message, Serial, }; //----------- A -------------------------------------------------------------- @@ -143,8 +143,7 @@ pub struct Soa { pub rname: N, /// The version number of the original copy of this zone. - // TODO: Define a dedicated serial number type. - pub serial: U32, + pub serial: Serial, /// The number of seconds to wait until refreshing the zone. pub refresh: U32, @@ -168,7 +167,7 @@ impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { ) -> Result { let (mname, rest) = N::split_from_message(message, range.start)?; let (rname, rest) = N::split_from_message(message, rest)?; - let (&serial, rest) = <&U32>::split_from_message(message, rest)?; + let (&serial, rest) = <&Serial>::split_from_message(message, rest)?; let (&refresh, rest) = <&U32>::split_from_message(message, rest)?; let (&retry, rest) = <&U32>::split_from_message(message, rest)?; let (&expire, rest) = <&U32>::split_from_message(message, rest)?; @@ -192,7 +191,7 @@ impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { fn parse_from(bytes: &'a [u8]) -> Result { let (mname, rest) = N::split_from(bytes)?; let (rname, rest) = N::split_from(rest)?; - let (&serial, rest) = <&U32>::split_from(rest)?; + let (&serial, rest) = <&Serial>::split_from(rest)?; let (&refresh, rest) = <&U32>::split_from(rest)?; let (&retry, rest) = <&U32>::split_from(rest)?; let (&expire, rest) = <&U32>::split_from(rest)?; From 7731a35d6be33d5aebdcceeb0bd844999691c725 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 14:29:26 +0100 Subject: [PATCH 027/111] [new_base] Add module 'build' --- src/new_base/build/mod.rs | 41 +++++++++++++++++++++++++++++++++++++++ src/new_base/mod.rs | 19 ++++++++++++------ src/new_base/parse/mod.rs | 2 +- 3 files changed, 55 insertions(+), 7 deletions(-) create mode 100644 src/new_base/build/mod.rs diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs new file mode 100644 index 000000000..e0a1c4925 --- /dev/null +++ b/src/new_base/build/mod.rs @@ -0,0 +1,41 @@ +//! Building DNS messages in the wire format. + +use core::fmt; + +//----------- Low-level building traits -------------------------------------- + +/// Building into a byte string. +pub trait BuildInto { + /// Append this value to the byte string. + /// + /// If the byte string is long enough to fit the message, the remaining + /// (unfilled) part of the byte string is returned. Otherwise, a + /// [`TruncationError`] is returned. + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError>; +} + +impl BuildInto for &T { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_into(bytes) + } +} + +//----------- TruncationError ------------------------------------------------ + +/// A DNS message did not fit in a buffer. +#[derive(Clone, Debug, PartialEq, Hash)] +pub struct TruncationError; + +//--- Formatting + +impl fmt::Display for TruncationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("A buffer was too small to fit a DNS message") + } +} diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 3becedb29..899225cf8 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -4,14 +4,11 @@ //! with DNS. Most importantly, it provides functionality for parsing and //! building DNS messages on the wire. +//--- DNS messages + mod message; pub use message::{Header, HeaderFlags, Message, SectionCounts}; -pub mod name; - -mod charstr; -pub use charstr::CharStr; - mod question; pub use question::{QClass, QType, Question, UnparsedQuestion}; @@ -20,7 +17,17 @@ pub use record::{ RClass, RType, Record, UnparsedRecord, UnparsedRecordData, TTL, }; -pub mod parse; +//--- Elements of DNS messages + +pub mod name; + +mod charstr; +pub use charstr::CharStr; mod serial; pub use serial::Serial; + +//--- Wire format + +pub mod build; +pub mod parse; diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index fba82d65c..022ff9df2 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -122,7 +122,7 @@ where //----------- ParseError ----------------------------------------------------- -/// A DNS parsing error. +/// A DNS message parsing error. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct ParseError; From c27dd1f94a32635c3452a7f926196768c6c6a592 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 14:32:30 +0100 Subject: [PATCH 028/111] [new_base/build] Add a 'Builder' for DNS messages --- src/new_base/build/builder.rs | 353 ++++++++++++++++++++++++++++++++++ src/new_base/build/mod.rs | 27 +++ 2 files changed, 380 insertions(+) create mode 100644 src/new_base/build/builder.rs diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs new file mode 100644 index 000000000..7488f8f6a --- /dev/null +++ b/src/new_base/build/builder.rs @@ -0,0 +1,353 @@ +//! A builder for DNS messages. + +use core::{ + marker::PhantomData, + mem::ManuallyDrop, + ptr::{self, NonNull}, +}; + +use zerocopy::{FromBytes, IntoBytes, SizeError}; + +use crate::new_base::{name::RevName, Header, Message}; + +use super::TruncationError; + +//----------- Builder -------------------------------------------------------- + +/// A DNS message builder. +pub struct Builder<'b> { + /// The message being built. + /// + /// The message is divided into four parts: + /// + /// - The message header (borrowed mutably by this type). + /// - Committed message contents (borrowed *immutably* by this type). + /// - Appended message contents (borrowed mutably by this type). + /// - Uninitialized message contents (borrowed mutably by this type). + message: NonNull, + + _message: PhantomData<&'b mut Message>, + + /// Context for building. + context: &'b mut BuilderContext, + + /// The commit point of this builder. + /// + /// Message contents up to this point are committed and cannot be removed + /// by this builder. Message contents following this (up to the size in + /// the builder context) are appended but uncommitted. + commit: usize, +} + +//--- Initialization + +impl<'b> Builder<'b> { + /// Construct a [`Builder`] from raw parts. + /// + /// # Safety + /// + /// - `message` is a valid reference for the lifetime `'b`. + /// - `message.header` is mutably borrowed for `'b`. + /// - `message.contents[..commit]` is immutably borrowed for `'b`. + /// - `message.contents[commit..]` is mutably borrowed for `'b`. + /// + /// - `message` and `context` are paired together. + /// + /// - `commit` is at most `context.size()`, which is at most + /// `context.max_size()`. + pub unsafe fn from_raw_parts( + message: NonNull, + context: &'b mut BuilderContext, + commit: usize, + ) -> Self { + Self { + message, + _message: PhantomData, + context, + commit, + } + } + + /// Initialize an empty [`Builder`]. + /// + /// # Panics + /// + /// Panics if the buffer is less than 12 bytes long (which is the minimum + /// possible size for a DNS message). + pub fn new( + buffer: &'b mut [u8], + context: &'b mut BuilderContext, + ) -> Self { + assert!(buffer.len() >= 12); + let message = Message::mut_from_bytes(buffer) + .map_err(SizeError::from) + .expect("A 'Message' can fit in 12 bytes"); + context.size = 0; + context.max_size = message.contents.len(); + + // SAFETY: 'message' and 'context' are now consistent. + unsafe { Self::from_raw_parts(message.into(), context, 0) } + } +} + +//--- Inspection + +impl<'b> Builder<'b> { + /// The message header. + /// + /// The header can be modified by the builder, and so is only available + /// for a short lifetime. Note that it implements [`Copy`]. + pub fn header(&self) -> &Header { + // SAFETY: 'message.header' is mutably borrowed by 'self'. + unsafe { &(*self.message.as_ptr()).header } + } + + /// Mutable access to the message header. + pub fn header_mut(&mut self) -> &mut Header { + // SAFETY: 'message.header' is mutably borrowed by 'self'. + unsafe { &mut (*self.message.as_ptr()).header } + } + + /// Committed message contents. + /// + /// The message contents are available for the lifetime `'b`; the builder + /// cannot be used to modify them since they have been committed. + pub fn committed(&self) -> &'b [u8] { + // SAFETY: 'message.contents[..commit]' is immutably borrowed by + // 'self'. + unsafe { &(*self.message.as_ptr()).contents[..self.commit] } + } + + /// The appended but uncommitted contents of the message. + /// + /// The builder can modify or rewind these contents, so they are offered + /// with a short lifetime. + pub fn appended(&self) -> &[u8] { + // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. + let range = self.commit..self.context.size; + unsafe { &(*self.message.as_ptr()).contents[range] } + } + + /// The appended but uncommitted contents of the message, mutably. + pub fn appended_mut(&mut self) -> &mut [u8] { + // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. + let range = self.commit..self.context.size; + unsafe { &mut (*self.message.as_ptr()).contents[range] } + } + + /// Uninitialized space in the message buffer. + /// + /// This can be filled manually, then marked as initialized using + /// [`Self::mark_appended()`]. + pub fn uninitialized(&mut self) -> &mut [u8] { + // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. + let range = self.context.size..self.context.max_size; + unsafe { &mut (*self.message.as_ptr()).contents[range] } + } + + /// The message with all committed contents. + /// + /// The header of the message can be modified by the builder, so the + /// returned reference has a short lifetime. The message contents can be + /// borrowed for a longer lifetime -- see [`Self::committed()`]. + pub fn message(&self) -> &Message { + // SAFETY: All of 'message' can be immutably borrowed by 'self'. + let message = unsafe { &*self.message.as_ptr() }; + let message = message.as_bytes(); + Message::ref_from_bytes_with_elems(message, self.commit) + .map_err(SizeError::from) + .expect("'message' represents a valid 'Message'") + } + + /// The message including any uncommitted contents. + /// + /// The header of the message can be modified by the builder, so the + /// returned reference has a short lifetime. The message contents can be + /// borrowed for a longer lifetime -- see [`Self::committed()`]. + pub fn cur_message(&self) -> &Message { + // SAFETY: All of 'message' can be immutably borrowed by 'self'. + let message = unsafe { &*self.message.as_ptr() }; + let message = message.as_bytes(); + Message::ref_from_bytes_with_elems(message, self.context.size) + .map_err(SizeError::from) + .expect("'message' represents a valid 'Message'") + } + + /// The builder context. + pub fn context(&self) -> &BuilderContext { + &*self.context + } + + /// Decompose this builder into raw parts. + /// + /// This returns three components: + /// + /// - The message buffer. The committed contents of the message (the + /// first `commit` bytes of the message contents) are borrowed immutably + /// for the lifetime `'b`. The remainder of the message buffer is + /// borrowed mutably for the lifetime `'b`. + /// + /// - Context for this builder. + /// + /// - The amount of data committed in the message (`commit`). + /// + /// The builder can be recomposed with [`Self::from_raw_parts()`]. + pub fn into_raw_parts( + self, + ) -> (NonNull, &'b mut BuilderContext, usize) { + // NOTE: The context has to be moved out carefully. + let (message, commit) = (self.message, self.commit); + let this = ManuallyDrop::new(self); + let this = (&*this) as *const Self; + // SAFETY: 'this' is a valid object that can be moved out of. + let context = unsafe { ptr::read(ptr::addr_of!((*this).context)) }; + (message, context, commit) + } +} + +//--- Interaction + +impl<'b> Builder<'b> { + /// Rewind the builder, removing all committed content. + pub fn rewind(&mut self) { + self.context.size = self.commit; + } + + /// Commit all appended content. + pub fn commit(&mut self) { + self.commit = self.context.size; + } + + /// Mark bytes in the buffer as initialized. + /// + /// The given number of bytes from the beginning of + /// [`Self::uninitialized()`] will be marked as initialized, and will be + /// treated as appended content in the buffer. + /// + /// # Panics + /// + /// Panics if the uninitialized buffer is smaller than the given number of + /// initialized bytes. + pub fn mark_appended(&mut self, amount: usize) { + assert!(self.context.max_size - self.context.size >= amount); + self.context.size += amount; + } + + /// Delegate to a new builder. + /// + /// Any content committed by the builder will be added as uncommitted + /// content for this builder. + pub fn delegate(&mut self) -> Builder<'_> { + let commit = self.context.size; + unsafe { + Builder::from_raw_parts(self.message, &mut *self.context, commit) + } + } + + /// Limit the total message size. + /// + /// The message will not be allowed to exceed the given size, in bytes. + /// Only the message header and contents are counted; the enclosing UDP + /// or TCP packet size is not considered. If the message already exceeds + /// this size, a [`TruncationError`] is returned. + /// + /// This size will apply to all + pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { + if self.context.size <= size { + self.context.max_size = size; + Ok(()) + } else { + Err(TruncationError) + } + } + + /// Append data of a known size using a closure. + /// + /// All the requested bytes must be initialized. If not enough free space + /// could be obtained, a [`TruncationError`] is returned. + pub fn append_with( + &mut self, + size: usize, + fill: impl FnOnce(&mut [u8]), + ) -> Result<(), TruncationError> { + self.uninitialized() + .get_mut(..size) + .ok_or(TruncationError) + .map(fill) + } + + /// Append some bytes. + /// + /// No name compression will be performed. + pub fn append_bytes( + &mut self, + bytes: &[u8], + ) -> Result<(), TruncationError> { + self.append_with(bytes.len(), |buffer| buffer.copy_from_slice(bytes)) + } + + /// Compress and append a domain name. + pub fn append_name( + &mut self, + name: &RevName, + ) -> Result<(), TruncationError> { + // TODO: Perform name compression. + self.append_with(name.len(), |mut buffer| { + // Write out the labels in the name in reverse. + for label in name.labels() { + let label_buffer; + let offset = buffer.len() - label.len() - 1; + (buffer, label_buffer) = buffer.split_at_mut(offset); + label_buffer[0] = label.len() as u8; + label_buffer[1..].copy_from_slice(label.as_bytes()); + } + }) + } +} + +//--- Drop + +impl Drop for Builder<'_> { + fn drop(&mut self) { + // Drop uncommitted content. + self.rewind(); + } +} + +//----------- BuilderContext ------------------------------------------------- + +/// Context for building a DNS message. +#[derive(Clone, Debug)] +pub struct BuilderContext { + // TODO: Name compression. + /// The current size of the message contents. + size: usize, + + /// The maximum size of the message contents. + max_size: usize, +} + +//--- Inspection + +impl BuilderContext { + /// The size of the message contents. + pub fn size(&self) -> usize { + self.size + } + + /// The maximum size of the message contents. + pub fn max_size(&self) -> usize { + self.max_size + } +} + +//--- Default + +impl Default for BuilderContext { + fn default() -> Self { + Self { + size: 0, + max_size: 65535 - core::mem::size_of::
(), + } + } +} diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index e0a1c4925..4aec6820d 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -2,6 +2,33 @@ use core::fmt; +mod builder; +pub use builder::{Builder, BuilderContext}; + +//----------- Message-aware building traits ---------------------------------- + +/// Building into a DNS message. +pub trait BuildIntoMessage { + // Append this value to the DNS message. + /// + /// If the byte string is long enough to fit the message, it is appended + /// using the given message builder and committed. Otherwise, a + /// [`TruncationError`] is returned. + fn build_into_message( + &self, + builder: Builder<'_>, + ) -> Result<(), TruncationError>; +} + +impl BuildIntoMessage for &T { + fn build_into_message( + &self, + builder: Builder<'_>, + ) -> Result<(), TruncationError> { + (**self).build_into_message(builder) + } +} + //----------- Low-level building traits -------------------------------------- /// Building into a byte string. From 993eda7bf2941b870dc9c8bcc017c1b459e8f531 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 14:44:17 +0100 Subject: [PATCH 029/111] [new_base/name/reversed] Impl building traits --- src/new_base/build/builder.rs | 16 +++------ src/new_base/name/reversed.rs | 62 +++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 7488f8f6a..93fda594b 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -10,7 +10,7 @@ use zerocopy::{FromBytes, IntoBytes, SizeError}; use crate::new_base::{name::RevName, Header, Message}; -use super::TruncationError; +use super::{BuildInto, TruncationError}; //----------- Builder -------------------------------------------------------- @@ -274,6 +274,7 @@ impl<'b> Builder<'b> { .get_mut(..size) .ok_or(TruncationError) .map(fill) + .map(|()| self.context.size += size) } /// Append some bytes. @@ -292,16 +293,9 @@ impl<'b> Builder<'b> { name: &RevName, ) -> Result<(), TruncationError> { // TODO: Perform name compression. - self.append_with(name.len(), |mut buffer| { - // Write out the labels in the name in reverse. - for label in name.labels() { - let label_buffer; - let offset = buffer.len() - label.len() - 1; - (buffer, label_buffer) = buffer.split_at_mut(offset); - label_buffer[0] = label.len() as u8; - label_buffer[1..].copy_from_slice(label.as_bytes()); - } - }) + name.build_into(self.uninitialized())?; + self.mark_appended(name.len()); + Ok(()) } } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 60a640f91..513a72582 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -12,6 +12,7 @@ use zerocopy::IntoBytes; use zerocopy_derive::*; use crate::new_base::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, }, @@ -97,6 +98,45 @@ impl RevName { } } +//--- Building into DNS messages + +impl BuildIntoMessage for RevName { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + builder.append_name(self)?; + builder.commit(); + Ok(()) + } +} + +//--- Building into byte strings + +impl BuildInto for RevName { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + if bytes.len() < self.len() { + return Err(TruncationError); + } + + let (mut buffer, rest) = bytes.split_at_mut(self.len()); + + // Write out the labels in the name in reverse. + for label in self.labels() { + let label_buffer; + let offset = buffer.len() - label.len() - 1; + (buffer, label_buffer) = buffer.split_at_mut(offset); + label_buffer[0] = label.len() as u8; + label_buffer[1..].copy_from_slice(label.as_bytes()); + } + + Ok(rest) + } +} + //--- Equality impl PartialEq for RevName { @@ -332,6 +372,17 @@ fn parse_segment<'a>( } } +//--- Building into DNS messages + +impl BuildIntoMessage for RevNameBuf { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + (**self).build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a> SplitFrom<'a> for RevNameBuf { @@ -369,6 +420,17 @@ impl<'a> ParseFrom<'a> for RevNameBuf { } } +//--- Building into byte strings + +impl BuildInto for RevNameBuf { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_into(bytes) + } +} + //--- Interaction impl RevNameBuf { From d40749234bf3ea75a18cdffaf2d0132c313c6f4b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 15:53:54 +0100 Subject: [PATCH 030/111] [new_base/build] Add convenience impls for '[u8]' --- src/new_base/build/mod.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 4aec6820d..108cc76f0 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -29,6 +29,17 @@ impl BuildIntoMessage for &T { } } +impl BuildIntoMessage for [u8] { + fn build_into_message( + &self, + mut builder: Builder<'_>, + ) -> Result<(), TruncationError> { + builder.append_bytes(self)?; + builder.commit(); + Ok(()) + } +} + //----------- Low-level building traits -------------------------------------- /// Building into a byte string. @@ -53,6 +64,21 @@ impl BuildInto for &T { } } +impl BuildInto for [u8] { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + if self.len() <= bytes.len() { + let (bytes, rest) = bytes.split_at_mut(self.len()); + bytes.copy_from_slice(self); + Ok(rest) + } else { + Err(TruncationError) + } + } +} + //----------- TruncationError ------------------------------------------------ /// A DNS message did not fit in a buffer. From 55c7854e874ae82a36bbb67ace8ad1ab8dce823c Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 15:54:14 +0100 Subject: [PATCH 031/111] [new_base/question] Impl building traits --- src/new_base/question.rs | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 121eedff4..f173a664f 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,10 +2,11 @@ use core::ops::Range; -use zerocopy::network_endian::U16; +use zerocopy::{network_endian::U16, IntoBytes}; use zerocopy_derive::*; use super::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, @@ -77,6 +78,23 @@ where } } +//--- Building into DNS messages + +impl BuildIntoMessage for Question +where + N: BuildIntoMessage, +{ + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.qname.build_into_message(builder.delegate())?; + builder.append_bytes(self.qtype.as_bytes())?; + builder.append_bytes(self.qclass.as_bytes())?; + Ok(()) + } +} + //--- Parsing from bytes impl<'a, N> SplitFrom<'a> for Question @@ -103,6 +121,23 @@ where } } +//--- Building into byte strings + +impl BuildInto for Question +where + N: BuildInto, +{ + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.qname.build_into(bytes)?; + bytes = self.qtype.as_bytes().build_into(bytes)?; + bytes = self.qclass.as_bytes().build_into(bytes)?; + Ok(bytes) + } +} + //----------- QType ---------------------------------------------------------- /// The type of a question. From 74e67836698385f4c7c994d5eea6f539b42cf743 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 15:54:24 +0100 Subject: [PATCH 032/111] [new_base/record] Support building and overhaul parsing --- src/new_base/record.rs | 179 ++++++++++++++++++++++++++++++++++------- 1 file changed, 149 insertions(+), 30 deletions(-) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 42336dc6f..79687cece 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -7,11 +7,12 @@ use core::{ use zerocopy::{ network_endian::{U16, U32}, - FromBytes, IntoBytes, + FromBytes, IntoBytes, SizeError, }; use zerocopy_derive::*; use super::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, @@ -64,19 +65,104 @@ impl Record { } } +//--- Parsing from DNS messages + +impl<'a, N, D> SplitFromMessage<'a> for Record +where + N: SplitFromMessage<'a>, + D: ParseFromMessage<'a>, +{ + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let (rname, rest) = N::split_from_message(message, start)?; + let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; + let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; + let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; + let (&size, rest) = <&U16>::split_from_message(message, rest)?; + let size: usize = size.get().into(); + let rdata = if message.as_bytes().len() - rest >= size { + D::parse_from_message(message, rest..rest + size)? + } else { + return Err(ParseError); + }; + + Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest + size)) + } +} + +impl<'a, N, D> ParseFromMessage<'a> for Record +where + N: SplitFromMessage<'a>, + D: ParseFromMessage<'a>, +{ + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let message = &message.as_bytes()[..range.end]; + let message = Message::ref_from_bytes(message) + .map_err(SizeError::from) + .expect("The input range ends past the message header"); + + let (this, rest) = Self::split_from_message(message, range.start)?; + + if rest == range.end { + Ok(this) + } else { + Err(ParseError) + } + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Record +where + N: BuildIntoMessage, + D: BuildIntoMessage, +{ + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.rname.build_into_message(builder.delegate())?; + builder.append_bytes(self.rtype.as_bytes())?; + builder.append_bytes(self.rclass.as_bytes())?; + builder.append_bytes(self.ttl.as_bytes())?; + + // The offset of the record data size. + let offset = builder.appended().len(); + builder.append_bytes(&0u16.to_be_bytes())?; + self.rdata.build_into_message(builder.delegate())?; + let size = builder.appended().len() - 2 - offset; + let size = + u16::try_from(size).expect("the record data never exceeds 64KiB"); + builder.appended_mut()[offset..offset + 2] + .copy_from_slice(&size.to_be_bytes()); + + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a, N, D> SplitFrom<'a> for Record where N: SplitFrom<'a>, - D: SplitFrom<'a>, + D: ParseFrom<'a>, { fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (rname, rest) = N::split_from(bytes)?; let (rtype, rest) = RType::read_from_prefix(rest)?; let (rclass, rest) = RClass::read_from_prefix(rest)?; let (ttl, rest) = TTL::read_from_prefix(rest)?; - let (rdata, rest) = D::split_from(rest)?; + let (size, rest) = U16::read_from_prefix(rest)?; + let size: usize = size.get().into(); + let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + let rdata = D::parse_from(rdata)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } @@ -92,12 +178,44 @@ where let (rtype, rest) = RType::read_from_prefix(rest)?; let (rclass, rest) = RClass::read_from_prefix(rest)?; let (ttl, rest) = TTL::read_from_prefix(rest)?; - let rdata = D::parse_from(rest)?; + let (size, rest) = U16::read_from_prefix(rest)?; + let size: usize = size.get().into(); + let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; + let rdata = D::parse_from(rdata)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } } +//--- Building into byte strings + +impl BuildInto for Record +where + N: BuildInto, + D: BuildInto, +{ + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.rname.build_into(bytes)?; + bytes = self.rtype.as_bytes().build_into(bytes)?; + bytes = self.rclass.as_bytes().build_into(bytes)?; + bytes = self.ttl.as_bytes().build_into(bytes)?; + + let (size, bytes) = + ::mut_from_prefix(bytes).map_err(|_| TruncationError)?; + let bytes_len = bytes.len(); + + let rest = self.rdata.build_into(bytes)?; + *size = u16::try_from(bytes_len - rest.len()) + .expect("the record data never exceeds 64KiB") + .into(); + + Ok(rest) + } +} + //----------- RType ---------------------------------------------------------- /// The type of a record. @@ -194,18 +312,6 @@ impl UnparsedRecordData { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for &'a UnparsedRecordData { - fn split_from_message( - message: &'a Message, - start: usize, - ) -> Result<(Self, usize), ParseError> { - let message = message.as_bytes(); - let bytes = message.get(start..).ok_or(ParseError)?; - let (this, rest) = Self::split_from(bytes)?; - Ok((this, message.len() - rest.len())) - } -} - impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { fn parse_from_message( message: &'a Message, @@ -217,26 +323,39 @@ impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { } } -//--- Parsing from bytes +//--- Building into DNS messages -impl<'a> SplitFrom<'a> for &'a UnparsedRecordData { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (size, rest) = U16::read_from_prefix(bytes)?; - let size = size.get() as usize; - let (data, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; - // SAFETY: 'data.len() == size' which is a 'u16'. - let this = unsafe { UnparsedRecordData::new_unchecked(data) }; - Ok((this, rest)) +impl BuildIntoMessage for UnparsedRecordData { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.0.build_into_message(builder) } } +//--- Parsing from bytes + impl<'a> ParseFrom<'a> for &'a UnparsedRecordData { fn parse_from(bytes: &'a [u8]) -> Result { - let (size, rest) = U16::read_from_prefix(bytes)?; - let size = size.get() as usize; - let data = <[u8]>::ref_from_bytes_with_elems(rest, size)?; - // SAFETY: 'data.len() == size' which is a 'u16'. - Ok(unsafe { UnparsedRecordData::new_unchecked(data) }) + if bytes.len() > 65535 { + // Too big to fit in an 'UnparsedRecordData'. + return Err(ParseError); + } + + // SAFETY: 'bytes.len()' fits within a 'u16'. + Ok(unsafe { UnparsedRecordData::new_unchecked(bytes) }) + } +} + +//--- Building into byte strings + +impl BuildInto for UnparsedRecordData { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.0.build_into(bytes) } } From 6772427ab1b31b9d7d05e5817aac47aba7ac5621 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 16:17:58 +0100 Subject: [PATCH 033/111] [new_base/charstr] Support building --- src/new_base/charstr.rs | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index dfe2bc8f9..fdd5e5bdf 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -6,6 +6,7 @@ use zerocopy::IntoBytes; use zerocopy_derive::*; use super::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, }, @@ -48,6 +49,20 @@ impl<'a> ParseFromMessage<'a> for &'a CharStr { } } +//--- Building into DNS messages + +impl BuildIntoMessage for CharStr { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + builder.append_bytes(&[self.octets.len() as u8])?; + builder.append_bytes(&self.octets)?; + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a> SplitFrom<'a> for &'a CharStr { @@ -75,6 +90,20 @@ impl<'a> ParseFrom<'a> for &'a CharStr { } } +//--- Building into byte strings + +impl BuildInto for CharStr { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + let (length, bytes) = + bytes.split_first_mut().ok_or(TruncationError)?; + *length = self.octets.len() as u8; + self.octets.build_into(bytes) + } +} + //--- Equality impl PartialEq for CharStr { From 59c33b2c5b2409ff7642de74ac9810e2a185ca41 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 16:18:08 +0100 Subject: [PATCH 034/111] [new_rdata/rfc1035] Impl building traits --- src/new_rdata/rfc1035.rs | 224 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 4e09c9727..dc25f0007 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -9,6 +9,7 @@ use zerocopy::{ use zerocopy_derive::*; use crate::new_base::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, }, @@ -73,6 +74,28 @@ impl fmt::Display for A { } } +//--- Building into DNS messages + +impl BuildIntoMessage for A { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.as_bytes().build_into_message(builder) + } +} + +//--- Building into byte strings + +impl BuildInto for A { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_into(bytes) + } +} + //----------- Ns ------------------------------------------------------------- /// The authoritative name server for this domain. @@ -94,6 +117,17 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Ns { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.name.build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { @@ -102,6 +136,17 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { } } +//--- Building into bytes + +impl BuildInto for Ns { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.name.build_into(bytes) + } +} + //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. @@ -123,6 +168,17 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Cname { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.name.build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { @@ -131,6 +187,17 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { } } +//--- Building into bytes + +impl BuildInto for Cname { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.name.build_into(bytes) + } +} + //----------- Soa ------------------------------------------------------------ /// The start of a zone of authority. @@ -185,6 +252,25 @@ impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Soa { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.mname.build_into_message(builder.delegate())?; + self.rname.build_into_message(builder.delegate())?; + builder.append_bytes(self.serial.as_bytes())?; + builder.append_bytes(self.refresh.as_bytes())?; + builder.append_bytes(self.retry.as_bytes())?; + builder.append_bytes(self.expire.as_bytes())?; + builder.append_bytes(self.minimum.as_bytes())?; + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { @@ -209,6 +295,24 @@ impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { } } +//--- Building into byte strings + +impl BuildInto for Soa { + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.mname.build_into(bytes)?; + bytes = self.rname.build_into(bytes)?; + bytes = self.serial.as_bytes().build_into(bytes)?; + bytes = self.refresh.as_bytes().build_into(bytes)?; + bytes = self.retry.as_bytes().build_into(bytes)?; + bytes = self.expire.as_bytes().build_into(bytes)?; + bytes = self.minimum.as_bytes().build_into(bytes)?; + Ok(bytes) + } +} + //----------- Wks ------------------------------------------------------------ /// Well-known services supported on this domain. @@ -253,6 +357,28 @@ impl fmt::Debug for Wks { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Wks { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.as_bytes().build_into_message(builder) + } +} + +//--- Building into byte strings + +impl BuildInto for Wks { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_into(bytes) + } +} + //----------- Ptr ------------------------------------------------------------ /// A pointer to another domain name. @@ -274,6 +400,17 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Ptr { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.name.build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { @@ -282,6 +419,17 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { } } +//--- Building into bytes + +impl BuildInto for Ptr { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.name.build_into(bytes) + } +} + //----------- Hinfo ---------------------------------------------------------- /// Information about the host computer. @@ -309,6 +457,20 @@ impl<'a> ParseFromMessage<'a> for Hinfo<'a> { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Hinfo<'_> { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.cpu.build_into_message(builder.delegate())?; + self.os.build_into_message(builder.delegate())?; + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a> ParseFrom<'a> for Hinfo<'a> { @@ -319,6 +481,19 @@ impl<'a> ParseFrom<'a> for Hinfo<'a> { } } +//--- Building into bytes + +impl BuildInto for Hinfo<'_> { + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.cpu.build_into(bytes)?; + bytes = self.os.build_into(bytes)?; + Ok(bytes) + } +} + //----------- Mx ------------------------------------------------------------- /// A host that can exchange mail for this domain. @@ -349,6 +524,20 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Mx { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + builder.append_bytes(self.preference.as_bytes())?; + self.exchange.build_into_message(builder.delegate())?; + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Mx { @@ -362,6 +551,19 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Mx { } } +//--- Building into byte strings + +impl BuildInto for Mx { + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.preference.as_bytes().build_into(bytes)?; + bytes = self.exchange.build_into(bytes)?; + Ok(bytes) + } +} + //----------- Txt ------------------------------------------------------------ /// Free-form text strings about this domain. @@ -389,6 +591,17 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Txt { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.content.build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a> ParseFrom<'a> for &'a Txt { @@ -403,3 +616,14 @@ impl<'a> ParseFrom<'a> for &'a Txt { Ok(unsafe { core::mem::transmute::<&'a [u8], Self>(bytes) }) } } + +//--- Building into byte strings + +impl BuildInto for Txt { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.content.build_into(bytes) + } +} From 2d845732dd74d34f76f02462cd3f7b9f2c61962a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 16:57:13 +0100 Subject: [PATCH 035/111] [build/builder] Improve unclear documentation --- src/new_base/build/builder.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 93fda594b..da6db0ba4 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -70,6 +70,9 @@ impl<'b> Builder<'b> { /// Initialize an empty [`Builder`]. /// + /// The message header is left uninitialized. Use [`Self::header_mut()`] + /// to initialize it. + /// /// # Panics /// /// Panics if the buffer is less than 12 bytes long (which is the minimum @@ -251,10 +254,17 @@ impl<'b> Builder<'b> { /// or TCP packet size is not considered. If the message already exceeds /// this size, a [`TruncationError`] is returned. /// - /// This size will apply to all + /// This size will apply to all builders for this message (including those + /// that delegated to `self`). It will not be automatically revoked if + /// message building fails. + /// + /// # Panics + /// + /// Panics if the given size is less than 12 bytes. pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { - if self.context.size <= size { - self.context.max_size = size; + assert!(size >= 12); + if self.context.size <= size - 12 { + self.context.max_size = size - 12; Ok(()) } else { Err(TruncationError) From a8433c91bc57ec5fff793ed8ee996cd8b778a802 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 27 Dec 2024 12:46:32 +0100 Subject: [PATCH 036/111] [new_rdata/rfc1035] gate 'Ipv4Addr' behind 'std' --- src/new_rdata/rfc1035.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index dc25f0007..6bc8eead2 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -1,6 +1,9 @@ //! Core record data types. -use core::{fmt, net::Ipv4Addr, ops::Range, str::FromStr}; +use core::{fmt, ops::Range, str::FromStr}; + +#[cfg(feature = "std")] +use std::net::Ipv4Addr; use zerocopy::{ network_endian::{U16, U32}, @@ -42,6 +45,7 @@ pub struct A { //--- Converting to and from 'Ipv4Addr' +#[cfg(feature = "std")] impl From for A { fn from(value: Ipv4Addr) -> Self { Self { @@ -50,6 +54,7 @@ impl From for A { } } +#[cfg(feature = "std")] impl From for Ipv4Addr { fn from(value: A) -> Self { Self::from(value.octets) @@ -58,6 +63,7 @@ impl From for Ipv4Addr { //--- Parsing from a string +#[cfg(feature = "std")] impl FromStr for A { type Err = ::Err; @@ -68,6 +74,7 @@ impl FromStr for A { //--- Formatting +#[cfg(feature = "std")] impl fmt::Display for A { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { Ipv4Addr::from(*self).fmt(f) From 86f14bb7a81477baf954e98a0f5324be572ff130 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 27 Dec 2024 12:50:53 +0100 Subject: [PATCH 037/111] [new_rdata/rfc1035] Gate more things under 'std' --- src/new_rdata/rfc1035.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 6bc8eead2..a05d3cb97 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -1,6 +1,8 @@ //! Core record data types. -use core::{fmt, ops::Range, str::FromStr}; +use core::ops::Range; +#[cfg(feature = "std")] +use core::{fmt, str::FromStr}; #[cfg(feature = "std")] use std::net::Ipv4Addr; @@ -338,6 +340,7 @@ pub struct Wks { //--- Formatting +#[cfg(feature = "std")] impl fmt::Debug for Wks { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { struct Ports<'a>(&'a [u8]); From 072827597b2dba6edd85d4a3478e128d6f9da4ac Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 27 Dec 2024 12:55:10 +0100 Subject: [PATCH 038/111] [new_base/build/builder] Remove unnecessary explicit lifetime in impl --- src/new_base/build/builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index da6db0ba4..75a9cfc69 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -210,7 +210,7 @@ impl<'b> Builder<'b> { //--- Interaction -impl<'b> Builder<'b> { +impl Builder<'_> { /// Rewind the builder, removing all committed content. pub fn rewind(&mut self) { self.context.size = self.commit; From f73ca63e1be6832277ee39cbb3c8539a24e316c9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Sat, 28 Dec 2024 09:51:16 +0100 Subject: [PATCH 039/111] [new_rdata/rfc1035] Support 'Display' outside 'std' --- src/new_rdata/rfc1035.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index a05d3cb97..f42cb3180 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -1,8 +1,9 @@ //! Core record data types. -use core::ops::Range; +use core::{fmt, ops::Range}; + #[cfg(feature = "std")] -use core::{fmt, str::FromStr}; +use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; @@ -76,10 +77,10 @@ impl FromStr for A { //--- Formatting -#[cfg(feature = "std")] impl fmt::Display for A { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - Ipv4Addr::from(*self).fmt(f) + let [a, b, c, d] = self.octets; + write!(f, "{a}.{b}.{c}.{d}") } } @@ -340,7 +341,6 @@ pub struct Wks { //--- Formatting -#[cfg(feature = "std")] impl fmt::Debug for Wks { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { struct Ports<'a>(&'a [u8]); @@ -360,7 +360,7 @@ impl fmt::Debug for Wks { } f.debug_struct("Wks") - .field("address", &Ipv4Addr::from(self.address)) + .field("address", &format_args!("{}", self.address)) .field("protocol", &self.protocol) .field("ports", &Ports(&self.ports)) .finish() From 1133b4c0860cef74a095dfa10caaaa6a1e49fcb7 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 14:08:55 +0100 Subject: [PATCH 040/111] [new_rdata] Inline 'rfc1035' and support 'rfc3596' --- src/new_rdata/mod.rs | 6 ++- src/new_rdata/rfc3596.rs | 98 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 src/new_rdata/rfc3596.rs diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 54afb39ee..248c02d37 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,3 +1,7 @@ //! Record data types. -pub mod rfc1035; +mod rfc1035; +pub use rfc1035::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; + +mod rfc3596; +pub use rfc3596::Aaaa; diff --git a/src/new_rdata/rfc3596.rs b/src/new_rdata/rfc3596.rs new file mode 100644 index 000000000..9a474aab1 --- /dev/null +++ b/src/new_rdata/rfc3596.rs @@ -0,0 +1,98 @@ +//! IPv6 record data types. + +#[cfg(feature = "std")] +use core::{fmt, str::FromStr}; + +#[cfg(feature = "std")] +use std::net::Ipv6Addr; + +use zerocopy::IntoBytes; +use zerocopy_derive::*; + +use crate::new_base::build::{ + self, BuildInto, BuildIntoMessage, TruncationError, +}; + +//----------- Aaaa ----------------------------------------------------------- + +/// The IPv6 address of a host responsible for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Aaaa { + /// The IPv6 address octets. + pub octets: [u8; 16], +} + +//--- Converting to and from 'Ipv6Addr' + +#[cfg(feature = "std")] +impl From for Aaaa { + fn from(value: Ipv6Addr) -> Self { + Self { + octets: value.octets(), + } + } +} + +#[cfg(feature = "std")] +impl From for Ipv6Addr { + fn from(value: Aaaa) -> Self { + Self::from(value.octets) + } +} + +//--- Parsing from a string + +#[cfg(feature = "std")] +impl FromStr for Aaaa { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + Ipv6Addr::from_str(s).map(Aaaa::from) + } +} + +//--- Formatting + +#[cfg(feature = "std")] +impl fmt::Display for Aaaa { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Ipv6Addr::from(*self).fmt(f) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Aaaa { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.as_bytes().build_into_message(builder) + } +} + +//--- Building into byte strings + +impl BuildInto for Aaaa { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_into(bytes) + } +} From 1a95ae811cac3264d8ba3e7ece2f814e37da80de Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 14:10:51 +0100 Subject: [PATCH 041/111] [new_rdata] Rename submodules with more intuitive names --- src/new_rdata/{rfc1035.rs => basic.rs} | 2 ++ src/new_rdata/{rfc3596.rs => ipv6.rs} | 2 ++ src/new_rdata/mod.rs | 8 ++++---- 3 files changed, 8 insertions(+), 4 deletions(-) rename src/new_rdata/{rfc1035.rs => basic.rs} (99%) rename src/new_rdata/{rfc3596.rs => ipv6.rs} (96%) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/basic.rs similarity index 99% rename from src/new_rdata/rfc1035.rs rename to src/new_rdata/basic.rs index f42cb3180..2abb38cf4 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/basic.rs @@ -1,4 +1,6 @@ //! Core record data types. +//! +//! See [RFC 1035](https://datatracker.ietf.org/doc/html/rfc1035). use core::{fmt, ops::Range}; diff --git a/src/new_rdata/rfc3596.rs b/src/new_rdata/ipv6.rs similarity index 96% rename from src/new_rdata/rfc3596.rs rename to src/new_rdata/ipv6.rs index 9a474aab1..606486d08 100644 --- a/src/new_rdata/rfc3596.rs +++ b/src/new_rdata/ipv6.rs @@ -1,4 +1,6 @@ //! IPv6 record data types. +//! +//! See [RFC 3596](https://datatracker.ietf.org/doc/html/rfc3596). #[cfg(feature = "std")] use core::{fmt, str::FromStr}; diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 248c02d37..bbec94d3d 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,7 +1,7 @@ //! Record data types. -mod rfc1035; -pub use rfc1035::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; +mod basic; +pub use basic::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; -mod rfc3596; -pub use rfc3596::Aaaa; +mod ipv6; +pub use ipv6::Aaaa; From 6bf26c5dd538ff0f46e13bbfafc5c1a27ab2abee Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 15:20:54 +0100 Subject: [PATCH 042/111] [new_base] Set up basic EDNS support Instead of 'new_base::opt', EDNS is now granted its own top-level module. This matches up well with 'crate::tsig'. --- src/lib.rs | 6 +- src/new_base/message.rs | 2 +- src/new_base/record.rs | 8 +- src/new_edns/mod.rs | 189 ++++++++++++++++++++++++++++++++++++++++ src/new_rdata/edns.rs | 55 ++++++++++++ src/new_rdata/mod.rs | 3 + 6 files changed, 257 insertions(+), 6 deletions(-) create mode 100644 src/new_edns/mod.rs create mode 100644 src/new_rdata/edns.rs diff --git a/src/lib.rs b/src/lib.rs index b2f7ac66c..4d08972f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -193,8 +193,6 @@ extern crate core; pub mod base; pub mod dep; pub mod net; -pub mod new_base; -pub mod new_rdata; pub mod rdata; pub mod resolv; pub mod sign; @@ -205,3 +203,7 @@ pub mod validate; pub mod validator; pub mod zonefile; pub mod zonetree; + +pub mod new_base; +pub mod new_edns; +pub mod new_rdata; diff --git a/src/new_base/message.rs b/src/new_base/message.rs index c07d605fa..e60ae76ff 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -186,7 +186,7 @@ impl fmt::Debug for HeaderFlags { .field("should_recurse (rd)", &self.should_recurse()) .field("can_recurse (ra)", &self.can_recurse()) .field("rcode", &self.rcode()) - .field("bits", &self.inner.get()) + .field("bits", &self.bits()) .finish() } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 79687cece..9522d80d3 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -317,9 +317,11 @@ impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { message: &'a Message, range: Range, ) -> Result { - let message = message.as_bytes(); - let bytes = message.get(range).ok_or(ParseError)?; - Self::parse_from(bytes) + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) } } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs new file mode 100644 index 000000000..014ac7104 --- /dev/null +++ b/src/new_edns/mod.rs @@ -0,0 +1,189 @@ +//! Support for Extended DNS (RFC 6891). +//! +//! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). + +use core::fmt; + +use zerocopy::{network_endian::U16, FromBytes, IntoBytes}; +use zerocopy_derive::*; + +use crate::{ + new_base::{ + parse::{ParseError, SplitFrom, SplitFromMessage}, + Message, + }, + new_rdata::Opt, +}; + +//----------- EdnsRecord ----------------------------------------------------- + +/// An Extended DNS record. +#[derive(Clone)] +pub struct EdnsRecord<'a> { + /// The largest UDP payload the DNS client supports, in bytes. + pub max_udp_payload: U16, + + /// An extension to the response code of the DNS message. + pub ext_rcode: u8, + + /// The Extended DNS version used by this message. + pub version: u8, + + /// Flags describing the message. + pub flags: EdnsFlags, + + /// Extended DNS options. + pub options: &'a Opt, +} + +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let bytes = message.as_bytes().get(start..).ok_or(ParseError)?; + let (this, rest) = Self::split_from(bytes)?; + Ok((this, message.as_bytes().len() - rest.len())) + } +} + +//--- Parsing from bytes + +impl<'a> SplitFrom<'a> for EdnsRecord<'a> { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + // Strip the record name (root) and the record type. + let rest = bytes.strip_prefix(&[0, 0, 41]).ok_or(ParseError)?; + + let (&max_udp_payload, rest) = <&U16>::split_from(rest)?; + let (&ext_rcode, rest) = <&u8>::split_from(rest)?; + let (&version, rest) = <&u8>::split_from(rest)?; + let (&flags, rest) = <&EdnsFlags>::split_from(rest)?; + + // Split the record size and data. + let (&size, rest) = <&U16>::split_from(rest)?; + let size: usize = size.get().into(); + let (options, rest) = Opt::ref_from_prefix_with_elems(rest, size)?; + + Ok(( + Self { + max_udp_payload, + ext_rcode, + version, + flags, + options, + }, + rest, + )) + } +} + +//----------- EdnsFlags ------------------------------------------------------ + +/// Extended DNS flags describing a message. +#[derive( + Copy, + Clone, + Default, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct EdnsFlags { + inner: U16, +} + +//--- Interaction + +impl EdnsFlags { + /// Get the specified flag bit. + fn get_flag(&self, pos: u32) -> bool { + self.inner.get() & (1 << pos) != 0 + } + + /// Set the specified flag bit. + fn set_flag(mut self, pos: u32, value: bool) -> Self { + self.inner &= !(1 << pos); + self.inner |= (value as u16) << pos; + self + } + + /// The raw flags bits. + pub fn bits(&self) -> u16 { + self.inner.get() + } + + /// Whether the client supports DNSSEC. + /// + /// See [RFC 3225](https://datatracker.ietf.org/doc/html/rfc3225). + pub fn is_dnssec_ok(&self) -> bool { + self.get_flag(15) + } + + /// Indicate support for DNSSEC to the server. + /// + /// See [RFC 3225](https://datatracker.ietf.org/doc/html/rfc3225). + pub fn set_dnssec_ok(self, value: bool) -> Self { + self.set_flag(15, value) + } +} + +//--- Formatting + +impl fmt::Debug for EdnsFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EdnsFlags") + .field("dnssec_ok (do)", &self.is_dnssec_ok()) + .field("bits", &self.bits()) + .finish() + } +} + +//----------- EdnsOption ----------------------------------------------------- + +/// An Extended DNS option. +#[derive(Debug)] +#[non_exhaustive] +pub enum EdnsOption<'b> { + /// An unknown option. + Unknown(OptionCode, &'b UnknownOption), +} + +//----------- OptionCode ----------------------------------------------------- + +/// An Extended DNS option code. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct OptionCode { + /// The option code. + pub code: U16, +} + +//----------- UnknownOption -------------------------------------------------- + +/// Data for an unknown Extended DNS option. +#[derive(Debug, FromBytes, IntoBytes, Immutable, Unaligned)] +#[repr(transparent)] +pub struct UnknownOption { + /// The unparsed option data. + pub octets: [u8], +} diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs new file mode 100644 index 000000000..89e146062 --- /dev/null +++ b/src/new_rdata/edns.rs @@ -0,0 +1,55 @@ +//! Record data types for Extended DNS. +//! +//! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). + +use zerocopy_derive::*; + +use crate::new_base::build::{ + self, BuildInto, BuildIntoMessage, TruncationError, +}; + +//----------- Opt ------------------------------------------------------------ + +/// Extended DNS options. +#[derive( + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(C)] // 'derive(KnownLayout)' doesn't work with 'repr(transparent)'. +pub struct Opt { + /// The raw serialized options. + contents: [u8], +} + +// TODO: Parsing the EDNS options. +// TODO: Formatting. + +//--- Building into DNS messages + +impl BuildIntoMessage for Opt { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.contents.build_into_message(builder) + } +} + +//--- Building into byte strings + +impl BuildInto for Opt { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.contents.build_into(bytes) + } +} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index bbec94d3d..1aad4cca0 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -5,3 +5,6 @@ pub use basic::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; mod ipv6; pub use ipv6::Aaaa; + +mod edns; +pub use edns::Opt; From aa0c59036ea53804ade672b6eeb6b8c9a0c9231e Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 15:45:05 +0100 Subject: [PATCH 043/111] [new_base/record] Add trait 'ParseRecordData' --- src/new_base/mod.rs | 3 +- src/new_base/record.rs | 73 +++++++++++++++++++++++++----------------- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 899225cf8..3c2e34068 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -14,7 +14,8 @@ pub use question::{QClass, QType, Question, UnparsedQuestion}; mod record; pub use record::{ - RClass, RType, Record, UnparsedRecord, UnparsedRecordData, TTL, + ParseRecordData, RClass, RType, Record, UnparsedRecord, + UnparsedRecordData, TTL, }; //--- Elements of DNS messages diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 9522d80d3..6f3a1daa8 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -70,7 +70,7 @@ impl Record { impl<'a, N, D> SplitFromMessage<'a> for Record where N: SplitFromMessage<'a>, - D: ParseFromMessage<'a>, + D: ParseRecordData<'a>, { fn split_from_message( message: &'a Message, @@ -83,7 +83,7 @@ where let (&size, rest) = <&U16>::split_from_message(message, rest)?; let size: usize = size.get().into(); let rdata = if message.as_bytes().len() - rest >= size { - D::parse_from_message(message, rest..rest + size)? + D::parse_record_data(message, rest..rest + size, rtype)? } else { return Err(ParseError); }; @@ -95,7 +95,7 @@ where impl<'a, N, D> ParseFromMessage<'a> for Record where N: SplitFromMessage<'a>, - D: ParseFromMessage<'a>, + D: ParseRecordData<'a>, { fn parse_from_message( message: &'a Message, @@ -152,7 +152,7 @@ where impl<'a, N, D> SplitFrom<'a> for Record where N: SplitFrom<'a>, - D: ParseFrom<'a>, + D: ParseRecordData<'a>, { fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (rname, rest) = N::split_from(bytes)?; @@ -162,7 +162,7 @@ where let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; - let rdata = D::parse_from(rdata)?; + let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } @@ -171,7 +171,7 @@ where impl<'a, N, D> ParseFrom<'a> for Record where N: SplitFrom<'a>, - D: ParseFrom<'a>, + D: ParseRecordData<'a>, { fn parse_from(bytes: &'a [u8]) -> Result { let (rname, rest) = N::split_from(bytes)?; @@ -181,7 +181,7 @@ where let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; - let rdata = D::parse_from(rdata)?; + let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -288,6 +288,24 @@ pub struct TTL { pub value: U32, } +//----------- ParseRecordData ------------------------------------------------ + +/// Parsing DNS record data. +pub trait ParseRecordData<'a>: Sized { + /// Parse DNS record data of the given type from a DNS message. + fn parse_record_data( + message: &'a Message, + range: Range, + rtype: RType, + ) -> Result; + + /// Parse DNS record data of the given type from a byte string. + fn parse_record_data_bytes( + bytes: &'a [u8], + rtype: RType, + ) -> Result; +} + //----------- UnparsedRecordData --------------------------------------------- /// Unparsed DNS record data. @@ -310,18 +328,29 @@ impl UnparsedRecordData { } } -//--- Parsing from DNS messages +//--- Parsing record data -impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { - fn parse_from_message( +impl<'a> ParseRecordData<'a> for &'a UnparsedRecordData { + fn parse_record_data( message: &'a Message, range: Range, + rtype: RType, + ) -> Result { + let bytes = message.as_bytes().get(range).ok_or(ParseError)?; + Self::parse_record_data_bytes(bytes, rtype) + } + + fn parse_record_data_bytes( + bytes: &'a [u8], + _rtype: RType, ) -> Result { - message - .as_bytes() - .get(range) - .ok_or(ParseError) - .and_then(Self::parse_from) + if bytes.len() > 65535 { + // Too big to fit in an 'UnparsedRecordData'. + return Err(ParseError); + } + + // SAFETY: 'bytes.len()' fits within a 'u16'. + Ok(unsafe { UnparsedRecordData::new_unchecked(bytes) }) } } @@ -336,20 +365,6 @@ impl BuildIntoMessage for UnparsedRecordData { } } -//--- Parsing from bytes - -impl<'a> ParseFrom<'a> for &'a UnparsedRecordData { - fn parse_from(bytes: &'a [u8]) -> Result { - if bytes.len() > 65535 { - // Too big to fit in an 'UnparsedRecordData'. - return Err(ParseError); - } - - // SAFETY: 'bytes.len()' fits within a 'u16'. - Ok(unsafe { UnparsedRecordData::new_unchecked(bytes) }) - } -} - //--- Building into byte strings impl BuildInto for UnparsedRecordData { From d11f9e577d41fe1bf70a993337b46f6f8a16a155 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 23:50:22 +0100 Subject: [PATCH 044/111] [new_base/parse] Make 'Split*' imply 'Parse*' --- src/new_base/name/label.rs | 10 +++++++- src/new_base/parse/mod.rs | 4 ++-- src/new_edns/mod.rs | 49 ++++++++++++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index b93b32f80..087692047 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -9,7 +9,7 @@ use core::{ use zerocopy_derive::*; -use crate::new_base::parse::{ParseError, SplitFrom}; +use crate::new_base::parse::{ParseError, ParseFrom, SplitFrom}; //----------- Label ---------------------------------------------------------- @@ -65,6 +65,14 @@ impl<'a> SplitFrom<'a> for &'a Label { } } +impl<'a> ParseFrom<'a> for &'a Label { + fn parse_from(bytes: &'a [u8]) -> Result { + Self::split_from(bytes).and_then(|(this, rest)| { + rest.is_empty().then_some(this).ok_or(ParseError) + }) + } +} + //--- Inspection impl Label { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 022ff9df2..31ad191ba 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -18,7 +18,7 @@ use super::Message; //----------- Message-aware parsing traits ----------------------------------- /// A type that can be parsed from a DNS message. -pub trait SplitFromMessage<'a>: Sized { +pub trait SplitFromMessage<'a>: Sized + ParseFromMessage<'a> { /// Parse a value of [`Self`] from the start of a byte string within a /// particular DNS message. /// @@ -80,7 +80,7 @@ where //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. -pub trait SplitFrom<'a>: Sized { +pub trait SplitFrom<'a>: Sized + ParseFrom<'a> { /// Parse a value of [`Self`] from the start of the byte string. /// /// If parsing is successful, the parsed value and the rest of the string diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 014ac7104..f4d12f9c1 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -2,14 +2,17 @@ //! //! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). -use core::fmt; +use core::{fmt, ops::Range}; use zerocopy::{network_endian::U16, FromBytes, IntoBytes}; use zerocopy_derive::*; use crate::{ new_base::{ - parse::{ParseError, SplitFrom, SplitFromMessage}, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, + SplitFromMessage, + }, Message, }, new_rdata::Opt, @@ -49,6 +52,19 @@ impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { } } +impl<'a> ParseFromMessage<'a> for EdnsRecord<'a> { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) + } +} + //--- Parsing from bytes impl<'a> SplitFrom<'a> for EdnsRecord<'a> { @@ -79,6 +95,31 @@ impl<'a> SplitFrom<'a> for EdnsRecord<'a> { } } +impl<'a> ParseFrom<'a> for EdnsRecord<'a> { + fn parse_from(bytes: &'a [u8]) -> Result { + // Strip the record name (root) and the record type. + let rest = bytes.strip_prefix(&[0, 0, 41]).ok_or(ParseError)?; + + let (&max_udp_payload, rest) = <&U16>::split_from(rest)?; + let (&ext_rcode, rest) = <&u8>::split_from(rest)?; + let (&version, rest) = <&u8>::split_from(rest)?; + let (&flags, rest) = <&EdnsFlags>::split_from(rest)?; + + // Split the record size and data. + let (&size, rest) = <&U16>::split_from(rest)?; + let size: usize = size.get().into(); + let options = Opt::ref_from_bytes_with_elems(rest, size)?; + + Ok(Self { + max_udp_payload, + ext_rcode, + version, + flags, + options, + }) + } +} + //----------- EdnsFlags ------------------------------------------------------ /// Extended DNS flags describing a message. @@ -181,8 +222,8 @@ pub struct OptionCode { //----------- UnknownOption -------------------------------------------------- /// Data for an unknown Extended DNS option. -#[derive(Debug, FromBytes, IntoBytes, Immutable, Unaligned)] -#[repr(transparent)] +#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C)] pub struct UnknownOption { /// The unparsed option data. pub octets: [u8], From ff7d9136b69b8bca98d26366d0ce0f2b86bf1ff7 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 23:51:57 +0100 Subject: [PATCH 045/111] [new_base/record] Add a default for 'parse_record_data()' --- src/new_base/record.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 6f3a1daa8..93038f8d9 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -297,7 +297,10 @@ pub trait ParseRecordData<'a>: Sized { message: &'a Message, range: Range, rtype: RType, - ) -> Result; + ) -> Result { + let bytes = message.as_bytes().get(range).ok_or(ParseError)?; + Self::parse_record_data_bytes(bytes, rtype) + } /// Parse DNS record data of the given type from a byte string. fn parse_record_data_bytes( @@ -331,15 +334,6 @@ impl UnparsedRecordData { //--- Parsing record data impl<'a> ParseRecordData<'a> for &'a UnparsedRecordData { - fn parse_record_data( - message: &'a Message, - range: Range, - rtype: RType, - ) -> Result { - let bytes = message.as_bytes().get(range).ok_or(ParseError)?; - Self::parse_record_data_bytes(bytes, rtype) - } - fn parse_record_data_bytes( bytes: &'a [u8], _rtype: RType, From ba260f0e60f830671ced8313a9e4d92290ca0e68 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 23:53:00 +0100 Subject: [PATCH 046/111] [new_rdata/basic] Use more capitalization in record data type names --- src/new_rdata/basic.rs | 63 ++++++++++++++++++++++++++++++++++-------- src/new_rdata/mod.rs | 2 +- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 2abb38cf4..cf6a99c39 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -164,14 +164,14 @@ impl BuildInto for Ns { /// The canonical name for this domain. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] -pub struct Cname { +pub struct CName { /// The canonical name. pub name: N, } //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for CName { fn parse_from_message( message: &'a Message, range: Range, @@ -182,7 +182,7 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { //--- Building into DNS messages -impl BuildIntoMessage for Cname { +impl BuildIntoMessage for CName { fn build_into_message( &self, builder: build::Builder<'_>, @@ -193,7 +193,7 @@ impl BuildIntoMessage for Cname { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for CName { fn parse_from(bytes: &'a [u8]) -> Result { N::parse_from(bytes).map(|name| Self { name }) } @@ -201,7 +201,7 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { //--- Building into bytes -impl BuildInto for Cname { +impl BuildInto for CName { fn build_into<'b>( &self, bytes: &'b mut [u8], @@ -442,11 +442,11 @@ impl BuildInto for Ptr { } } -//----------- Hinfo ---------------------------------------------------------- +//----------- HInfo ---------------------------------------------------------- /// Information about the host computer. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Hinfo<'a> { +pub struct HInfo<'a> { /// The CPU type. pub cpu: &'a CharStr, @@ -456,7 +456,7 @@ pub struct Hinfo<'a> { //--- Parsing from DNS messages -impl<'a> ParseFromMessage<'a> for Hinfo<'a> { +impl<'a> ParseFromMessage<'a> for HInfo<'a> { fn parse_from_message( message: &'a Message, range: Range, @@ -471,7 +471,7 @@ impl<'a> ParseFromMessage<'a> for Hinfo<'a> { //--- Building into DNS messages -impl BuildIntoMessage for Hinfo<'_> { +impl BuildIntoMessage for HInfo<'_> { fn build_into_message( &self, mut builder: build::Builder<'_>, @@ -485,7 +485,7 @@ impl BuildIntoMessage for Hinfo<'_> { //--- Parsing from bytes -impl<'a> ParseFrom<'a> for Hinfo<'a> { +impl<'a> ParseFrom<'a> for HInfo<'a> { fn parse_from(bytes: &'a [u8]) -> Result { let (cpu, rest) = <&CharStr>::split_from(bytes)?; let os = <&CharStr>::parse_from(rest)?; @@ -495,7 +495,7 @@ impl<'a> ParseFrom<'a> for Hinfo<'a> { //--- Building into bytes -impl BuildInto for Hinfo<'_> { +impl BuildInto for HInfo<'_> { fn build_into<'b>( &self, mut bytes: &'b mut [u8], @@ -586,7 +586,23 @@ pub struct Txt { content: [u8], } -// TODO: Support for iterating over the contained 'CharStr's. +//--- Interaction + +impl Txt { + /// Iterate over the [`CharStr`]s in this record. + pub fn iter<'a>( + &'a self, + ) -> impl Iterator> + 'a { + // NOTE: A TXT record always has at least one 'CharStr' within. + let first = <&CharStr>::split_from(&self.content); + core::iter::successors(Some(first), |prev| { + prev.as_ref() + .ok() + .map(|(_elem, rest)| <&CharStr>::split_from(rest)) + }) + .map(|result| result.map(|(elem, _rest)| elem)) + } +} //--- Parsing from DNS messages @@ -639,3 +655,26 @@ impl BuildInto for Txt { self.content.build_into(bytes) } } + +//--- Formatting + +impl fmt::Debug for Txt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Content<'a>(&'a Txt); + impl fmt::Debug for Content<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut list = f.debug_list(); + for elem in self.0.iter() { + if let Ok(elem) = elem { + list.entry(&elem); + } else { + list.entry(&ParseError); + } + } + list.finish() + } + } + + f.debug_tuple("Txt").field(&Content(self)).finish() + } +} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 1aad4cca0..8fb32032f 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,7 +1,7 @@ //! Record data types. mod basic; -pub use basic::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; +pub use basic::{CName, HInfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; mod ipv6; pub use ipv6::Aaaa; From 6d358a3e81ef08142511fed1a1f462fee5eeed98 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 23:53:36 +0100 Subject: [PATCH 047/111] [new_rdata] Define enum 'RecordData' --- src/new_base/record.rs | 34 ++++++++ src/new_rdata/mod.rs | 176 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 93038f8d9..6354611ed 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -240,6 +240,40 @@ pub struct RType { pub code: U16, } +//--- Associated Constants + +impl RType { + /// The type of an [`A`](crate::new_rdata::A) record. + pub const A: Self = Self { code: U16::new(1) }; + + /// The type of an [`Ns`](crate::new_rdata::Ns) record. + pub const NS: Self = Self { code: U16::new(2) }; + + /// The type of a [`CName`](crate::new_rdata::CName) record. + pub const CNAME: Self = Self { code: U16::new(5) }; + + /// The type of an [`Soa`](crate::new_rdata::Soa) record. + pub const SOA: Self = Self { code: U16::new(6) }; + + /// The type of a [`Wks`](crate::new_rdata::Wks) record. + pub const WKS: Self = Self { code: U16::new(11) }; + + /// The type of a [`Ptr`](crate::new_rdata::Ptr) record. + pub const PTR: Self = Self { code: U16::new(12) }; + + /// The type of a [`HInfo`](crate::new_rdata::HInfo) record. + pub const HINFO: Self = Self { code: U16::new(13) }; + + /// The type of a [`Mx`](crate::new_rdata::Mx) record. + pub const MX: Self = Self { code: U16::new(15) }; + + /// The type of a [`Txt`](crate::new_rdata::Txt) record. + pub const TXT: Self = Self { code: U16::new(16) }; + + /// The type of an [`Aaaa`](crate::new_rdata::Aaaa) record. + pub const AAAA: Self = Self { code: U16::new(28) }; +} + //----------- RClass --------------------------------------------------------- /// The class of a record. diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 8fb32032f..67f4d9cb3 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,5 +1,19 @@ //! Record data types. +use core::ops::Range; + +use zerocopy_derive::*; + +use crate::new_base::{ + build::{BuildInto, BuildIntoMessage, Builder, TruncationError}, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, ParseRecordData, RType, +}; + +//----------- Concrete record data types ------------------------------------- + mod basic; pub use basic::{CName, HInfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; @@ -8,3 +22,165 @@ pub use ipv6::Aaaa; mod edns; pub use edns::Opt; + +//----------- RecordData ----------------------------------------------------- + +/// DNS record data. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub enum RecordData<'a, N> { + /// The IPv4 address of a host responsible for this domain. + A(&'a A), + + /// The authoritative name server for this domain. + Ns(Ns), + + /// The canonical name for this domain. + CName(CName), + + /// The start of a zone of authority. + Soa(Soa), + + /// Well-known services supported on this domain. + Wks(&'a Wks), + + /// A pointer to another domain name. + Ptr(Ptr), + + /// Information about the host computer. + HInfo(HInfo<'a>), + + /// A host that can exchange mail for this domain. + Mx(Mx), + + /// Free-form text strings about this domain. + Txt(&'a Txt), + + /// The IPv6 address of a host responsible for this domain. + Aaaa(&'a Aaaa), + + /// Data for an unknown DNS record type. + Unknown(RType, &'a UnknownRecordData), +} + +//--- Parsing record data + +impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> +where + N: SplitFrom<'a> + SplitFromMessage<'a>, +{ + fn parse_record_data( + message: &'a Message, + range: Range, + rtype: RType, + ) -> Result { + match rtype { + RType::A => <&A>::parse_from_message(message, range).map(Self::A), + RType::NS => Ns::parse_from_message(message, range).map(Self::Ns), + RType::CNAME => { + CName::parse_from_message(message, range).map(Self::CName) + } + RType::SOA => { + Soa::parse_from_message(message, range).map(Self::Soa) + } + RType::WKS => { + <&Wks>::parse_from_message(message, range).map(Self::Wks) + } + RType::PTR => { + Ptr::parse_from_message(message, range).map(Self::Ptr) + } + RType::HINFO => { + HInfo::parse_from_message(message, range).map(Self::HInfo) + } + RType::MX => Mx::parse_from_message(message, range).map(Self::Mx), + RType::TXT => { + <&Txt>::parse_from_message(message, range).map(Self::Txt) + } + RType::AAAA => { + <&Aaaa>::parse_from_message(message, range).map(Self::Aaaa) + } + _ => <&UnknownRecordData>::parse_from_message(message, range) + .map(|data| Self::Unknown(rtype, data)), + } + } + + fn parse_record_data_bytes( + bytes: &'a [u8], + rtype: RType, + ) -> Result { + match rtype { + RType::A => <&A>::parse_from(bytes).map(Self::A), + RType::NS => Ns::parse_from(bytes).map(Self::Ns), + RType::CNAME => CName::parse_from(bytes).map(Self::CName), + RType::SOA => Soa::parse_from(bytes).map(Self::Soa), + RType::WKS => <&Wks>::parse_from(bytes).map(Self::Wks), + RType::PTR => Ptr::parse_from(bytes).map(Self::Ptr), + RType::HINFO => HInfo::parse_from(bytes).map(Self::HInfo), + RType::MX => Mx::parse_from(bytes).map(Self::Mx), + RType::TXT => <&Txt>::parse_from(bytes).map(Self::Txt), + RType::AAAA => <&Aaaa>::parse_from(bytes).map(Self::Aaaa), + _ => <&UnknownRecordData>::parse_from(bytes) + .map(|data| Self::Unknown(rtype, data)), + } + } +} + +//--- Building record data + +impl<'a, N> BuildIntoMessage for RecordData<'a, N> +where + N: BuildIntoMessage, +{ + fn build_into_message( + &self, + builder: Builder<'_>, + ) -> Result<(), TruncationError> { + match self { + Self::A(r) => r.build_into_message(builder), + Self::Ns(r) => r.build_into_message(builder), + Self::CName(r) => r.build_into_message(builder), + Self::Soa(r) => r.build_into_message(builder), + Self::Wks(r) => r.build_into_message(builder), + Self::Ptr(r) => r.build_into_message(builder), + Self::HInfo(r) => r.build_into_message(builder), + Self::Txt(r) => r.build_into_message(builder), + Self::Aaaa(r) => r.build_into_message(builder), + Self::Mx(r) => r.build_into_message(builder), + Self::Unknown(_, r) => r.octets.build_into_message(builder), + } + } +} + +impl<'a, N> BuildInto for RecordData<'a, N> +where + N: BuildInto, +{ + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + match self { + Self::A(r) => r.build_into(bytes), + Self::Ns(r) => r.build_into(bytes), + Self::CName(r) => r.build_into(bytes), + Self::Soa(r) => r.build_into(bytes), + Self::Wks(r) => r.build_into(bytes), + Self::Ptr(r) => r.build_into(bytes), + Self::HInfo(r) => r.build_into(bytes), + Self::Txt(r) => r.build_into(bytes), + Self::Aaaa(r) => r.build_into(bytes), + Self::Mx(r) => r.build_into(bytes), + Self::Unknown(_, r) => r.octets.build_into(bytes), + } + } +} + +//----------- UnknownRecordData ---------------------------------------------- + +/// Data for an unknown DNS record type. +#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C)] +pub struct UnknownRecordData { + /// The unparsed option data. + pub octets: [u8], +} From bd08a473ff7f1fd1983769a7cf1d3150ab6d638b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 31 Dec 2024 00:00:30 +0100 Subject: [PATCH 048/111] [new_rdata/basic] Elide lifetime as per Clippy --- src/new_rdata/basic.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index cf6a99c39..4807784b3 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -590,9 +590,9 @@ pub struct Txt { impl Txt { /// Iterate over the [`CharStr`]s in this record. - pub fn iter<'a>( - &'a self, - ) -> impl Iterator> + 'a { + pub fn iter( + &self, + ) -> impl Iterator> + '_ { // NOTE: A TXT record always has at least one 'CharStr' within. let first = <&CharStr>::split_from(&self.content); core::iter::successors(Some(first), |prev| { From 54f131effb1a84457d52d4bcc91c840f959fd19a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 31 Dec 2024 11:56:00 +0100 Subject: [PATCH 049/111] [new_rdata] Elide lifetimes as per clippy --- src/new_rdata/mod.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 67f4d9cb3..3228608cc 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -127,10 +127,7 @@ where //--- Building record data -impl<'a, N> BuildIntoMessage for RecordData<'a, N> -where - N: BuildIntoMessage, -{ +impl BuildIntoMessage for RecordData<'_, N> { fn build_into_message( &self, builder: Builder<'_>, @@ -151,10 +148,7 @@ where } } -impl<'a, N> BuildInto for RecordData<'a, N> -where - N: BuildInto, -{ +impl BuildInto for RecordData<'_, N> { fn build_into<'b>( &self, bytes: &'b mut [u8], From 6881f6ad4b0de73acf651ce0ab281a673cc177c9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 1 Jan 2025 11:35:09 +0100 Subject: [PATCH 050/111] Set up a 'domain-macros' crate 'domain-macros' will provide: - 'derive' macros for DNS-specific (zero-copy) serialization - 'derive' macros for building and parsing specialized DNS messages - 'derive' macros for composing clients and servers The first use case will replace 'zerocopy'. --- Cargo.lock | 10 ++++++++++ Cargo.toml | 10 ++++++---- macros/Cargo.toml | 23 +++++++++++++++++++++++ macros/src/lib.rs | 3 +++ 4 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 macros/Cargo.toml create mode 100644 macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 7506702e2..d9833efa5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -245,6 +245,7 @@ dependencies = [ "arc-swap", "bytes", "chrono", + "domain-macros", "futures-util", "hashbrown 0.14.5", "heapless", @@ -282,6 +283,15 @@ dependencies = [ "zerocopy-derive 0.8.13", ] +[[package]] +name = "domain-macros" +version = "0.10.3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 9d078a5e5..041d83731 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,7 @@ +[workspace] +resolver = "2" +members = [".", "./macros"] + [package] name = "domain" version = "0.10.3" @@ -16,11 +20,9 @@ readme = "README.md" keywords = ["DNS", "domain"] license = "BSD-3-Clause" -[lib] -name = "domain" -path = "src/lib.rs" - [dependencies] +domain-macros = { path = "./macros", version = "0.10.3" } + arbitrary = { version = "1.4.1", optional = true, features = ["derive"] } octseq = { version = "0.5.2", default-features = false } time = { version = "0.3.1", default-features = false } diff --git a/macros/Cargo.toml b/macros/Cargo.toml new file mode 100644 index 000000000..b94e60bbd --- /dev/null +++ b/macros/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "domain-macros" + +# Copied from 'domain'. +version = "0.10.3" +rust-version = "1.68.2" +edition = "2021" + +authors = ["NLnet Labs "] +description = "Procedural macros for the `domain` crate." +documentation = "https://docs.rs/domain-macros" +homepage = "https://github.com/nlnetlabs/domain/" +repository = "https://github.com/nlnetlabs/domain/" +keywords = ["DNS", "domain"] +license = "BSD-3-Clause" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +syn = "2.0" +quote = "1.0" diff --git a/macros/src/lib.rs b/macros/src/lib.rs new file mode 100644 index 000000000..0bf081d73 --- /dev/null +++ b/macros/src/lib.rs @@ -0,0 +1,3 @@ +//! Procedural macros for [`domain`]. +//! +//! [`domain`]: https://docs.rs/domain From 155fb5acabbde4a0258f9e0a2377e8e0fee76169 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 1 Jan 2025 12:32:37 +0100 Subject: [PATCH 051/111] [new_base/parse] Define 'ParseBytesByRef' for deriving --- src/new_base/parse/mod.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 31ad191ba..920395bd9 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -97,6 +97,33 @@ pub trait ParseFrom<'a>: Sized { fn parse_from(bytes: &'a [u8]) -> Result; } +/// Zero-copy parsing from a byte string. +/// +/// # Safety +/// +/// Every implementation of [`ParseBytesByRef`] must satisfy the invariants +/// documented on [`parse_bytes_by_ref()`]. An incorrect implementation is +/// considered to cause undefined behaviour. +/// +/// Implementing types should almost always be unaligned, but foregoing this +/// will not cause undefined behaviour (however, it will be very confusing for +/// users). +pub unsafe trait ParseBytesByRef { + /// Interpret a byte string as an instance of [`Self`]. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The whole byte string will be used. If the input is not a + /// valid instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let this: &T = T::parse_bytes_by_ref(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError>; +} + //--- Carrying over 'zerocopy' traits // NOTE: We can't carry over 'read_from_prefix' because the trait impls would From cf170e6af2f433fcd9d6ddfc8de5ab6212f811e4 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 1 Jan 2025 19:25:00 +0100 Subject: [PATCH 052/111] [macros] Define 'ImplSkeleton' and prepare a basic derive --- macros/Cargo.toml | 13 ++- macros/src/impls.rs | 237 ++++++++++++++++++++++++++++++++++++++++++++ macros/src/lib.rs | 52 ++++++++++ src/lib.rs | 8 ++ 4 files changed, 306 insertions(+), 4 deletions(-) create mode 100644 macros/src/impls.rs diff --git a/macros/Cargo.toml b/macros/Cargo.toml index b94e60bbd..263db27af 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -17,7 +17,12 @@ license = "BSD-3-Clause" [lib] proc-macro = true -[dependencies] -proc-macro2 = "1.0" -syn = "2.0" -quote = "1.0" +[dependencies.proc-macro2] +version = "1.0" + +[dependencies.syn] +version = "2.0" +features = ["visit"] + +[dependencies.quote] +version = "1.0" diff --git a/macros/src/impls.rs b/macros/src/impls.rs new file mode 100644 index 000000000..c5af737fb --- /dev/null +++ b/macros/src/impls.rs @@ -0,0 +1,237 @@ +//! Helpers for generating `impl` blocks. + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{punctuated::Punctuated, visit::Visit, *}; + +//----------- ImplSkeleton --------------------------------------------------- + +/// The skeleton of an `impl` block. +pub struct ImplSkeleton { + /// Lifetime parameters for the `impl` block. + pub lifetimes: Vec, + + /// Type parameters for the `impl` block. + pub types: Vec, + + /// Const generic parameters for the `impl` block. + pub consts: Vec, + + /// Whether the `impl` is unsafe. + pub unsafety: Option, + + /// The trait being implemented. + pub bound: Path, + + /// The type being implemented on. + pub subject: Path, + + /// The where clause of the `impl` block. + pub where_clause: WhereClause, + + /// The contents of the `impl`. + pub contents: Block, + + /// A `const` block for asserting requirements. + pub requirements: Block, +} + +impl ImplSkeleton { + /// Construct an [`ImplSkeleton`] for a [`DeriveInput`]. + pub fn new(input: &DeriveInput, unsafety: bool, bound: Path) -> Self { + let mut lifetimes = Vec::new(); + let mut types = Vec::new(); + let mut consts = Vec::new(); + let mut subject_args = Punctuated::new(); + + for param in &input.generics.params { + match param { + GenericParam::Lifetime(value) => { + lifetimes.push(value.clone()); + let id = value.lifetime.clone(); + subject_args.push(GenericArgument::Lifetime(id)); + } + + GenericParam::Type(value) => { + types.push(value.clone()); + let id = value.ident.clone(); + let id = TypePath { + qself: None, + path: Path { + leading_colon: None, + segments: [PathSegment { + ident: id, + arguments: PathArguments::None, + }] + .into_iter() + .collect(), + }, + }; + subject_args.push(GenericArgument::Type(id.into())); + } + + GenericParam::Const(value) => { + consts.push(value.clone()); + let id = value.ident.clone(); + let id = TypePath { + qself: None, + path: Path { + leading_colon: None, + segments: [PathSegment { + ident: id, + arguments: PathArguments::None, + }] + .into_iter() + .collect(), + }, + }; + subject_args.push(GenericArgument::Type(id.into())); + } + } + } + + let unsafety = unsafety.then_some(::default()); + + let subject = Path { + leading_colon: None, + segments: [PathSegment { + ident: input.ident.clone(), + arguments: PathArguments::AngleBracketed( + AngleBracketedGenericArguments { + colon2_token: None, + lt_token: Default::default(), + args: subject_args, + gt_token: Default::default(), + }, + ), + }] + .into_iter() + .collect(), + }; + + let where_clause = + input.generics.where_clause.clone().unwrap_or(WhereClause { + where_token: Default::default(), + predicates: Punctuated::new(), + }); + + let contents = Block { + brace_token: Default::default(), + stmts: Vec::new(), + }; + + let requirements = Block { + brace_token: Default::default(), + stmts: Vec::new(), + }; + + Self { + lifetimes, + types, + consts, + unsafety, + bound, + subject, + where_clause, + contents, + requirements, + } + } + + /// Require a bound for a type. + /// + /// If the type is concrete, a verifying statement is added for it. + /// Otherwise, it is added to the where clause. + pub fn require_bound(&mut self, target: Type, bound: TypeParamBound) { + if self.is_concrete(&target) { + // Add a concrete requirement for this bound. + self.requirements.stmts.push(parse_quote! { + const _: fn() = || { + fn assert_impl() {} + assert_impl::<#target>(); + }; + }); + } else { + // Add this bound to the `where` clause. + let mut bounds = Punctuated::new(); + bounds.push_value(bound); + let pred = WherePredicate::Type(PredicateType { + lifetimes: None, + bounded_ty: target, + colon_token: Default::default(), + bounds, + }); + self.where_clause.predicates.push_value(pred); + } + } + + /// Whether a type is concrete within this `impl` block. + pub fn is_concrete(&self, target: &Type) -> bool { + struct ConcretenessVisitor<'a> { + /// The `impl` skeleton being added to. + skeleton: &'a ImplSkeleton, + + /// Whether the visited type is concrete. + is_concrete: bool, + } + + impl<'ast> Visit<'ast> for ConcretenessVisitor<'_> { + fn visit_lifetime(&mut self, i: &'ast Lifetime) { + self.is_concrete = self.is_concrete + && self + .skeleton + .lifetimes + .iter() + .all(|l| l.lifetime != *i); + } + + fn visit_ident(&mut self, i: &'ast Ident) { + self.is_concrete = self.is_concrete + && self.skeleton.types.iter().all(|t| t.ident != *i); + self.is_concrete = self.is_concrete + && self.skeleton.consts.iter().all(|c| c.ident != *i); + } + } + + let mut visitor = ConcretenessVisitor { + skeleton: self, + is_concrete: true, + }; + + visitor.visit_type(target); + + visitor.is_concrete + } +} + +impl ToTokens for ImplSkeleton { + fn to_tokens(&self, tokens: &mut TokenStream) { + let Self { + lifetimes, + types, + consts, + unsafety, + bound, + subject, + where_clause, + contents, + requirements, + } = self; + + quote! { + #unsafety + impl<#(#lifetimes,)* #(#types,)* #(#consts,)*> + #bound for #subject + #where_clause + #contents + } + .to_tokens(tokens); + + if !requirements.stmts.is_empty() { + quote! { + const _: () = #requirements; + } + .to_tokens(tokens); + } + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 0bf081d73..68f956adc 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,3 +1,55 @@ //! Procedural macros for [`domain`]. //! //! [`domain`]: https://docs.rs/domain + +use proc_macro as pm; +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::*; + +mod impls; +use impls::ImplSkeleton; + +//----------- ParseBytesByRef ------------------------------------------------ + +#[proc_macro_derive(ParseBytesByRef)] +pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let bound = parse_quote!(::domain::new_base::parse::ParseBytesByRef); + let mut skeleton = ImplSkeleton::new(&input, true, bound); + + let data = match input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'ParseBytesByRef' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'ParseBytesByRef' can only be 'derive'd for 'struct's", + )); + } + }; + + // TODO: Ensure that the type is 'repr(C)' or 'repr(transparent)'. + + // Every field must implement 'ParseBytesByRef'. + for field in data.fields.iter() { + let bound = + parse_quote!(::domain::new_base::parse::ParseBytesByRef); + skeleton.require_bound(field.ty.clone(), bound); + } + + // TODO: Implement 'parse_bytes_by_ref()' in 'skeleton.contents'. + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/src/lib.rs b/src/lib.rs index 4d08972f0..6fe1aeec9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -190,6 +190,14 @@ extern crate std; #[macro_use] extern crate core; +// The 'domain-macros' crate introduces 'derive' macros which can be used by +// users of the 'domain' crate, but also by the 'domain' crate itself. Within +// those macros, references to declarations in the 'domain' crate are written +// as '::domain::*' ... but this doesn't work when those proc macros are used +// by the 'domain' crate itself. The alias introduced here fixes this: now +// '::domain' means the same thing within this crate as in dependents of it. +extern crate self as domain; + pub mod base; pub mod dep; pub mod net; From 89ee79785dd273a356c8f180592bc7308f0269d8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 12:58:04 +0100 Subject: [PATCH 053/111] Expand 'ParseBytesByRef' and largely finish its derive macro --- macros/src/lib.rs | 104 +++++++++++++++++++++++--- src/lib.rs | 9 ++- src/new_base/parse/mod.rs | 153 +++++++++++++++++++++++++++++++++++++- 3 files changed, 249 insertions(+), 17 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 68f956adc..255567046 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -4,7 +4,8 @@ use proc_macro as pm; use proc_macro2::TokenStream; -use quote::ToTokens; +use quote::{quote, ToTokens}; +use spanned::Spanned; use syn::*; mod impls; @@ -15,10 +16,7 @@ use impls::ImplSkeleton; #[proc_macro_derive(ParseBytesByRef)] pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { fn inner(input: DeriveInput) -> Result { - let bound = parse_quote!(::domain::new_base::parse::ParseBytesByRef); - let mut skeleton = ImplSkeleton::new(&input, true, bound); - - let data = match input.data { + let data = match &input.data { Data::Struct(data) => data, Data::Enum(data) => { return Err(Error::new_spanned( @@ -36,14 +34,98 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // TODO: Ensure that the type is 'repr(C)' or 'repr(transparent)'. - // Every field must implement 'ParseBytesByRef'. - for field in data.fields.iter() { - let bound = - parse_quote!(::domain::new_base::parse::ParseBytesByRef); - skeleton.require_bound(field.ty.clone(), bound); + // Split up the last field from the rest. + let mut fields = data.fields.iter(); + let Some(last) = fields.next_back() else { + // This type has no fields. Return a simple implementation. + let (impl_generics, ty_generics, where_clause) = + input.generics.split_for_impl(); + let name = input.ident; + + return Ok(quote! { + impl #impl_generics + ::domain::new_base::parse::ParseBytesByRef + for #name #ty_generics + #where_clause { + fn parse_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &Self, + ::domain::new_base::parse::ParseError, + > { + Ok(unsafe { &*bytes.as_ptr().cast::() }) + } + + fn ptr_with_address( + &self, + addr: *const (), + ) -> *const Self { + addr.cast() + } + } + }); + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let bound = parse_quote!(::domain::new_base::parse::ParseBytesByRef); + let mut skeleton = ImplSkeleton::new(&input, true, bound); + + // Establish bounds on the fields. + for field in fields.clone() { + // This field should implement 'SplitBytesByRef'. + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::SplitBytesByRef), + ); } + // The last field should implement 'ParseBytesByRef'. + skeleton.require_bound( + last.ty.clone(), + parse_quote!(::domain::new_base::parse::ParseBytesByRef), + ); - // TODO: Implement 'parse_bytes_by_ref()' in 'skeleton.contents'. + // Define 'parse_bytes_by_ref()'. + let tys = fields.clone().map(|f| &f.ty); + let last_ty = &last.ty; + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &Self, + ::domain::new_base::parse::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::parse::SplitBytesByRef> + ::split_bytes_by_ref(bytes)?;)* + let last = + <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + ::parse_bytes_by_ref(bytes)?; + let ptr = + <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - By + Ok(unsafe { &*(ptr as *const Self) }) + } + }); + + // Define 'ptr_with_address()'. + let last_name = match last.ident.as_ref() { + Some(ident) => Member::Named(ident.clone()), + None => Member::Unnamed(Index { + index: data.fields.len() as u32 - 1, + span: last.ty.span(), + }), + }; + skeleton.contents.stmts.push(parse_quote! { + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + ::ptr_with_address(&self.#last_name, addr) + as *const Self + } + }); Ok(skeleton.into_token_stream().into()) } diff --git a/src/lib.rs b/src/lib.rs index 6fe1aeec9..40b4efd7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -187,17 +187,18 @@ #[macro_use] extern crate std; -#[macro_use] -extern crate core; - // The 'domain-macros' crate introduces 'derive' macros which can be used by // users of the 'domain' crate, but also by the 'domain' crate itself. Within // those macros, references to declarations in the 'domain' crate are written // as '::domain::*' ... but this doesn't work when those proc macros are used -// by the 'domain' crate itself. The alias introduced here fixes this: now +// in the 'domain' crate itself. The alias introduced here fixes this: now // '::domain' means the same thing within this crate as in dependents of it. extern crate self as domain; +// Re-export 'core' for use in macros. +#[doc(hidden)] +pub use core as __core; + pub mod base; pub mod dep; pub mod net; diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 920395bd9..b218646c9 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -97,13 +97,47 @@ pub trait ParseFrom<'a>: Sized { fn parse_from(bytes: &'a [u8]) -> Result; } +/// Zero-copy parsing from the start of a byte string. +/// +/// # Safety +/// +/// Every implementation of [`SplitBytesByRef`] must satisfy the invariants +/// documented on [`split_bytes_by_ref()`]. An incorrect implementation is +/// considered to cause undefined behaviour. +/// +/// Implementing types should almost always be unaligned, but foregoing this +/// will not cause undefined behaviour (however, it will be very confusing for +/// users). +pub unsafe trait SplitBytesByRef: ParseBytesByRef { + /// Interpret a byte string as an instance of [`Self`]. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The length of [`Self`] will be determined, possibly based + /// on the contents (but not the length!) of the input, and the remaining + /// bytes will be returned. If the input does not begin with a valid + /// instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let (this, rest) = T::split_bytes_by_ref(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this) + rest.len()`. + /// - `bytes.as_ptr().offset(size_of_val(this)) == rest.as_ptr()`. + fn split_bytes_by_ref(bytes: &[u8]) + -> Result<(&Self, &[u8]), ParseError>; +} + /// Zero-copy parsing from a byte string. /// /// # Safety /// /// Every implementation of [`ParseBytesByRef`] must satisfy the invariants -/// documented on [`parse_bytes_by_ref()`]. An incorrect implementation is -/// considered to cause undefined behaviour. +/// documented on [`parse_bytes_by_ref()`] and [`ptr_with_address()`]. An +/// incorrect implementation is considered to cause undefined behaviour. +/// +/// [`parse_bytes_by_ref()`]: Self::parse_bytes_by_ref() +/// [`ptr_with_address()`]: Self::ptr_with_address() /// /// Implementing types should almost always be unaligned, but foregoing this /// will not cause undefined behaviour (however, it will be very confusing for @@ -122,6 +156,121 @@ pub unsafe trait ParseBytesByRef { /// - `bytes.as_ptr() == this as *const T as *const u8`. /// - `bytes.len() == core::mem::size_of_val(this)`. fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError>; + + /// Change the address of a pointer to [`Self`]. + /// + /// When [`Self`] is used as the last field in a type that also implements + /// [`ParseBytesByRef`], it may be dynamically sized, and so a pointer (or + /// reference) to it may include additional metadata. This metadata is + /// included verbatim in any reference/pointer to the containing type. + /// + /// When the containing type implements [`ParseBytesByRef`], it needs to + /// construct a reference/pointer to itself, which includes this metadata. + /// Rust does not currently offer a general way to extract this metadata + /// or pair it with another address, so this function is necessary. The + /// caller can construct a reference to [`Self`], then change its address + /// to point to the containing type, then cast that pointer to the right + /// type. + /// + /// # Implementing + /// + /// Most users will derive [`ParseBytesByRef`] and so don't need to worry + /// about this. For manual implementations: + /// + /// In the future, an adequate default implementation for this function + /// may be provided. Until then, it should be implemented using one of + /// the following expressions: + /// + /// ```ignore + /// fn ptr_with_address( + /// &self, + /// addr: *const (), + /// ) -> *const Self { + /// // If 'Self' is Sized: + /// addr.cast::() + /// + /// // If 'Self' is an aggregate whose last field is 'last': + /// self.last.ptr_with_address(addr) as *const Self + /// } + /// ``` + /// + /// # Invariants + /// + /// For the statement `let result = Self::ptr_with_address(ptr, addr);`: + /// + /// - `result as usize == addr as usize`. + /// - `core::ptr::metadata(result) == core::ptr::metadata(ptr)`. + fn ptr_with_address(&self, addr: *const ()) -> *const Self; +} + +unsafe impl SplitBytesByRef for u8 { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + bytes.split_first().ok_or(ParseError) + } +} + +unsafe impl ParseBytesByRef for u8 { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + let [result] = bytes else { + return Err(ParseError); + }; + + return Ok(result); + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + +unsafe impl ParseBytesByRef for [u8] { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Ok(bytes) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + core::ptr::slice_from_raw_parts(addr.cast(), self.len()) + } +} + +unsafe impl SplitBytesByRef for [u8; N] { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + if bytes.len() < N { + Err(ParseError) + } else { + let (bytes, rest) = bytes.split_at(N); + + // SAFETY: + // - It is known that 'bytes.len() == N'. + // - Thus '&bytes' has the same layout as '[u8; N]'. + // - Thus it is safe to cast a pointer to it to '[u8; N]'. + // - The referenced data has the same lifetime as the output. + Ok((unsafe { &*bytes.as_ptr().cast::<[u8; N]>() }, rest)) + } + } +} + +unsafe impl ParseBytesByRef for [u8; N] { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + if bytes.len() != N { + Err(ParseError) + } else { + // SAFETY: + // - It is known that 'bytes.len() == N'. + // - Thus '&bytes' has the same layout as '[u8; N]'. + // - Thus it is safe to cast a pointer to it to '[u8; N]'. + // - The referenced data has the same lifetime as the output. + Ok(unsafe { &*bytes.as_ptr().cast::<[u8; N]>() }) + } + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } } //--- Carrying over 'zerocopy' traits From baaa8d2ad63ec90b634489aceff18a39615f7b74 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 13:13:11 +0100 Subject: [PATCH 054/111] [macros] Add module 'repr' for checking for stable layouts --- macros/src/lib.rs | 5 +++- macros/src/repr.rs | 68 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 macros/src/repr.rs diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 255567046..755467061 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -11,6 +11,9 @@ use syn::*; mod impls; use impls::ImplSkeleton; +mod repr; +use repr::Repr; + //----------- ParseBytesByRef ------------------------------------------------ #[proc_macro_derive(ParseBytesByRef)] @@ -32,7 +35,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }; - // TODO: Ensure that the type is 'repr(C)' or 'repr(transparent)'. + let _ = Repr::determine(&input.attrs)?; // Split up the last field from the rest. let mut fields = data.fields.iter(); diff --git a/macros/src/repr.rs b/macros/src/repr.rs new file mode 100644 index 000000000..428ef2a10 --- /dev/null +++ b/macros/src/repr.rs @@ -0,0 +1,68 @@ +//! Determining the memory layout of a type. + +use proc_macro2::Span; +use syn::{punctuated::Punctuated, *}; + +//----------- Repr ----------------------------------------------------------- + +/// The memory representation of a type. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum Repr { + /// Transparent to an underlying field. + Transparent, + + /// Compatible with C. + C, +} + +impl Repr { + /// Determine the representation for a type from its attributes. + /// + /// This will fail if a stable representation cannot be found. + pub fn determine(attrs: &[Attribute]) -> Result { + let mut repr = None; + for attr in attrs { + if !attr.path().is_ident("repr") { + continue; + } + + let nested = attr.parse_args_with( + Punctuated::::parse_terminated, + )?; + + // We don't check for consistency in the 'repr' attributes, since + // the compiler should be doing that for us anyway. This lets us + // ignore conflicting 'repr's entirely. + for meta in nested { + match meta { + Meta::Path(p) if p.is_ident("transparent") => { + repr = Some(Repr::Transparent); + } + + Meta::Path(p) if p.is_ident("C") => { + repr = Some(Repr::C); + } + + Meta::Path(p) if p.is_ident("Rust") => { + return Err(Error::new_spanned(p, + "repr(Rust) is not stable, cannot derive this for it")); + } + + meta => { + // We still need to error out here, in case a future + // version of Rust introduces more memory layout data + return Err(Error::new_spanned( + meta, + "unrecognized repr attribute", + )); + } + } + } + } + + repr.ok_or_else(|| { + Error::new(Span::call_site(), + "repr(C) or repr(transparent) must be specified to derive this") + }) + } +} From d4fc42f21e1c73e22fc01371456aa50a38567046 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 13:20:42 +0100 Subject: [PATCH 055/111] [new_base/parse] Implement '*BytesByRef' for '[T; N]' --- src/new_base/parse/mod.rs | 41 ++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index b218646c9..4c8917308 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -235,36 +235,33 @@ unsafe impl ParseBytesByRef for [u8] { } } -unsafe impl SplitBytesByRef for [u8; N] { +unsafe impl SplitBytesByRef for [T; N] { fn split_bytes_by_ref( - bytes: &[u8], + mut bytes: &[u8], ) -> Result<(&Self, &[u8]), ParseError> { - if bytes.len() < N { - Err(ParseError) - } else { - let (bytes, rest) = bytes.split_at(N); - - // SAFETY: - // - It is known that 'bytes.len() == N'. - // - Thus '&bytes' has the same layout as '[u8; N]'. - // - Thus it is safe to cast a pointer to it to '[u8; N]'. - // - The referenced data has the same lifetime as the output. - Ok((unsafe { &*bytes.as_ptr().cast::<[u8; N]>() }, rest)) + let start = bytes.as_ptr(); + for _ in 0..N { + (_, bytes) = T::split_bytes_by_ref(bytes)?; } + + // SAFETY: + // - 'T::split_bytes_by_ref()' was called 'N' times on successive + // positions, thus the original 'bytes' starts with 'N' instances + // of 'T' (even if 'T' is a ZST and so all instances overlap). + // - 'N' consecutive 'T's have the same layout as '[T; N]'. + // - Thus it is safe to cast 'start' to '[T; N]'. + // - The referenced data has the same lifetime as the output. + Ok((unsafe { &*start.cast::<[T; N]>() }, bytes)) } } -unsafe impl ParseBytesByRef for [u8; N] { +unsafe impl ParseBytesByRef for [T; N] { fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - if bytes.len() != N { - Err(ParseError) + let (this, rest) = Self::split_bytes_by_ref(bytes)?; + if rest.is_empty() { + Ok(this) } else { - // SAFETY: - // - It is known that 'bytes.len() == N'. - // - Thus '&bytes' has the same layout as '[u8; N]'. - // - Thus it is safe to cast a pointer to it to '[u8; N]'. - // - The referenced data has the same lifetime as the output. - Ok(unsafe { &*bytes.as_ptr().cast::<[u8; N]>() }) + Err(ParseError) } } From a1bfc4f1042edb4f43f3f3ae054e70fa323cf630 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 13:49:00 +0100 Subject: [PATCH 056/111] [macros] Add a derive macro for 'SplitBytesByRef' 'ParseBytesByRef' now requires all implementing types to be unaligned. Otherwise, padding bytes wouldn't be accounted for properly, e.g. in // Has alignment of largest field: 8 bytes #[repr(C)] pub struct Foo { a: u8, // 7 bytes of padding here b: u64, } The 'derive' can't tell how much padding to use, so it would parse a '[u8; 9]' as a valid instance of 'Foo'. Every 'repr(C)' would have to use 'repr(packed)' too. --- macros/src/lib.rs | 121 +++++++++++++++++++++++++++++++++++++- macros/src/repr.rs | 23 +++++++- src/new_base/parse/mod.rs | 6 +- 3 files changed, 141 insertions(+), 9 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 755467061..8cb26183f 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -14,6 +14,113 @@ use impls::ImplSkeleton; mod repr; use repr::Repr; +//----------- SplitBytesByRef ------------------------------------------------ + +#[proc_macro_derive(SplitBytesByRef)] +pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'SplitBytesByRef' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'SplitBytesByRef' can only be 'derive'd for 'struct's", + )); + } + }; + + let _ = Repr::determine(&input.attrs, "SplitBytesByRef")?; + + // Split up the last field from the rest. + let mut fields = data.fields.iter(); + let Some(last) = fields.next_back() else { + // This type has no fields. Return a simple implementation. + let (impl_generics, ty_generics, where_clause) = + input.generics.split_for_impl(); + let name = input.ident; + + return Ok(quote! { + unsafe impl #impl_generics + ::domain::new_base::parse::SplitBytesByRef + for #name #ty_generics + #where_clause { + fn split_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&Self, &[::domain::__core::primitive::u8]), + ::domain::new_base::parse::ParseError, + > { + Ok(( + unsafe { &*bytes.as_ptr().cast::() }, + bytes, + )) + } + } + }); + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let bound = parse_quote!(::domain::new_base::parse::SplitBytesByRef); + let mut skeleton = ImplSkeleton::new(&input, true, bound); + + // Establish bounds on the fields. + for field in data.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::SplitBytesByRef), + ); + } + + // Define 'split_bytes_by_ref()'. + let tys = fields.clone().map(|f| &f.ty); + let last_ty = &last.ty; + skeleton.contents.stmts.push(parse_quote! { + fn split_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&Self, &[::domain::__core::primitive::u8]), + ::domain::new_base::parse::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::parse::SplitBytesByRef> + ::split_bytes_by_ref(bytes)?;)* + let (last, rest) = + <#last_ty as ::domain::new_base::parse::SplitBytesByRef> + ::split_bytes_by_ref(bytes)?; + let ptr = + <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. + Ok((unsafe { &*(ptr as *const Self) }, rest)) + } + }); + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + //----------- ParseBytesByRef ------------------------------------------------ #[proc_macro_derive(ParseBytesByRef)] @@ -35,7 +142,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }; - let _ = Repr::determine(&input.attrs)?; + let _ = Repr::determine(&input.attrs, "ParseBytesByRef")?; // Split up the last field from the rest. let mut fields = data.fields.iter(); @@ -46,7 +153,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { let name = input.ident; return Ok(quote! { - impl #impl_generics + unsafe impl #impl_generics ::domain::new_base::parse::ParseBytesByRef for #name #ty_generics #where_clause { @@ -109,7 +216,15 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { ::ptr_with_address(last, start as *const ()); // SAFETY: - // - By + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. Ok(unsafe { &*(ptr as *const Self) }) } }); diff --git a/macros/src/repr.rs b/macros/src/repr.rs index 428ef2a10..80c900eb6 100644 --- a/macros/src/repr.rs +++ b/macros/src/repr.rs @@ -1,7 +1,7 @@ //! Determining the memory layout of a type. use proc_macro2::Span; -use syn::{punctuated::Punctuated, *}; +use syn::{punctuated::Punctuated, spanned::Spanned, *}; //----------- Repr ----------------------------------------------------------- @@ -19,7 +19,7 @@ impl Repr { /// Determine the representation for a type from its attributes. /// /// This will fail if a stable representation cannot be found. - pub fn determine(attrs: &[Attribute]) -> Result { + pub fn determine(attrs: &[Attribute], bound: &str) -> Result { let mut repr = None; for attr in attrs { if !attr.path().is_ident("repr") { @@ -45,7 +45,24 @@ impl Repr { Meta::Path(p) if p.is_ident("Rust") => { return Err(Error::new_spanned(p, - "repr(Rust) is not stable, cannot derive this for it")); + format!("repr(Rust) is not stable, cannot derive {bound} for it"))); + } + + Meta::Path(p) if p.is_ident("packed") => { + // The alignment can be set to 1 safely. + } + + Meta::List(meta) + if meta.path.is_ident("packed") + || meta.path.is_ident("aligned") => + { + let span = meta.span(); + let lit: LitInt = parse2(meta.tokens)?; + let n: usize = lit.base10_parse()?; + if n != 1 { + return Err(Error::new(span, + format!("'Self' must be unaligned to derive {bound}"))); + } } meta => { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 4c8917308..eb4815d04 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -139,9 +139,9 @@ pub unsafe trait SplitBytesByRef: ParseBytesByRef { /// [`parse_bytes_by_ref()`]: Self::parse_bytes_by_ref() /// [`ptr_with_address()`]: Self::ptr_with_address() /// -/// Implementing types should almost always be unaligned, but foregoing this -/// will not cause undefined behaviour (however, it will be very confusing for -/// users). +/// Implementing types must also have no alignment (i.e. a valid instance of +/// [`Self`] can occur at any address). This eliminates the possibility of +/// padding bytes, even when [`Self`] is part of a larger aggregate type. pub unsafe trait ParseBytesByRef { /// Interpret a byte string as an instance of [`Self`]. /// From 2164b8757b5aef89e5cbc6980f0f0c6449cf31b8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 14:08:43 +0100 Subject: [PATCH 057/111] Use '{Parse,Split}BytesByRef' instead of 'zerocopy' --- src/new_base/parse/mod.rs | 95 +++++++++++++++++++++++---------------- src/new_base/question.rs | 11 +++-- src/new_base/record.rs | 28 ++++++------ src/new_base/serial.rs | 6 +-- src/new_edns/mod.rs | 13 +++--- src/new_rdata/basic.rs | 8 ++-- src/new_rdata/ipv6.rs | 6 +-- src/new_rdata/mod.rs | 3 +- 8 files changed, 92 insertions(+), 78 deletions(-) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index eb4815d04..c5faf0440 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -2,7 +2,10 @@ use core::{fmt, ops::Range}; -use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; +use zerocopy::{ + network_endian::{U16, U32}, + FromBytes, IntoBytes, +}; mod message; pub use message::{MessagePart, ParseMessage, VisitMessagePart}; @@ -43,37 +46,26 @@ pub trait ParseFromMessage<'a>: Sized { ) -> Result; } -//--- Carrying over 'zerocopy' traits - -// NOTE: We can't carry over 'read_from_prefix' because the trait impls would -// conflict. We kept 'ref_from_prefix' since it's more general. - -impl<'a, T: ?Sized> SplitFromMessage<'a> for &'a T -where - T: FromBytes + KnownLayout + Immutable, -{ +impl<'a, T: ?Sized + SplitBytesByRef> SplitFromMessage<'a> for &'a T { fn split_from_message( message: &'a Message, start: usize, ) -> Result<(Self, usize), ParseError> { let message = message.as_bytes(); let bytes = message.get(start..).ok_or(ParseError)?; - let (this, rest) = T::ref_from_prefix(bytes)?; + let (this, rest) = T::split_bytes_by_ref(bytes)?; Ok((this, message.len() - rest.len())) } } -impl<'a, T: ?Sized> ParseFromMessage<'a> for &'a T -where - T: FromBytes + KnownLayout + Immutable, -{ +impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { fn parse_from_message( message: &'a Message, range: Range, ) -> Result { let message = message.as_bytes(); let bytes = message.get(range).ok_or(ParseError)?; - Ok(T::ref_from_bytes(bytes)?) + T::parse_bytes_by_ref(bytes) } } @@ -97,6 +89,18 @@ pub trait ParseFrom<'a>: Sized { fn parse_from(bytes: &'a [u8]) -> Result; } +impl<'a, T: ?Sized + SplitBytesByRef> SplitFrom<'a> for &'a T { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + T::split_bytes_by_ref(bytes).map_err(|_| ParseError) + } +} + +impl<'a, T: ?Sized + ParseBytesByRef> ParseFrom<'a> for &'a T { + fn parse_from(bytes: &'a [u8]) -> Result { + T::parse_bytes_by_ref(bytes).map_err(|_| ParseError) + } +} + /// Zero-copy parsing from the start of a byte string. /// /// # Safety @@ -225,6 +229,42 @@ unsafe impl ParseBytesByRef for u8 { } } +unsafe impl SplitBytesByRef for U16 { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + Self::ref_from_prefix(bytes).map_err(Into::into) + } +} + +unsafe impl ParseBytesByRef for U16 { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Self::ref_from_bytes(bytes).map_err(Into::into) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + +unsafe impl SplitBytesByRef for U32 { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + Self::ref_from_prefix(bytes).map_err(Into::into) + } +} + +unsafe impl ParseBytesByRef for U32 { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Self::ref_from_bytes(bytes).map_err(Into::into) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + unsafe impl ParseBytesByRef for [u8] { fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { Ok(bytes) @@ -270,29 +310,6 @@ unsafe impl ParseBytesByRef for [T; N] { } } -//--- Carrying over 'zerocopy' traits - -// NOTE: We can't carry over 'read_from_prefix' because the trait impls would -// conflict. We kept 'ref_from_prefix' since it's more general. - -impl<'a, T: ?Sized> SplitFrom<'a> for &'a T -where - T: FromBytes + KnownLayout + Immutable, -{ - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - T::ref_from_prefix(bytes).map_err(|_| ParseError) - } -} - -impl<'a, T: ?Sized> ParseFrom<'a> for &'a T -where - T: FromBytes + KnownLayout + Immutable, -{ - fn parse_from(bytes: &'a [u8]) -> Result { - T::ref_from_bytes(bytes).map_err(|_| ParseError) - } -} - //----------- ParseError ----------------------------------------------------- /// A DNS message parsing error. diff --git a/src/new_base/question.rs b/src/new_base/question.rs index f173a664f..81411aaf0 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,6 +2,7 @@ use core::ops::Range; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{network_endian::U16, IntoBytes}; use zerocopy_derive::*; @@ -150,11 +151,10 @@ where PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct QType { @@ -174,11 +174,10 @@ pub struct QType { PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct QClass { diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 6354611ed..36d4c58dd 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -5,6 +5,7 @@ use core::{ ops::{Deref, Range}, }; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{ network_endian::{U16, U32}, FromBytes, IntoBytes, SizeError, @@ -156,9 +157,9 @@ where { fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (rname, rest) = N::split_from(bytes)?; - let (rtype, rest) = RType::read_from_prefix(rest)?; - let (rclass, rest) = RClass::read_from_prefix(rest)?; - let (ttl, rest) = TTL::read_from_prefix(rest)?; + let (&rtype, rest) = <&RType>::split_from(rest)?; + let (&rclass, rest) = <&RClass>::split_from(rest)?; + let (&ttl, rest) = <&TTL>::split_from(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; @@ -175,9 +176,9 @@ where { fn parse_from(bytes: &'a [u8]) -> Result { let (rname, rest) = N::split_from(bytes)?; - let (rtype, rest) = RType::read_from_prefix(rest)?; - let (rclass, rest) = RClass::read_from_prefix(rest)?; - let (ttl, rest) = TTL::read_from_prefix(rest)?; + let (&rtype, rest) = <&RType>::split_from(rest)?; + let (&rclass, rest) = <&RClass>::split_from(rest)?; + let (&ttl, rest) = <&TTL>::split_from(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; @@ -228,11 +229,10 @@ where PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct RType { @@ -286,11 +286,10 @@ impl RType { PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct RClass { @@ -310,11 +309,10 @@ pub struct RClass { PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct TTL { diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index fe00923c3..f351e1a46 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -8,6 +8,7 @@ use core::{ ops::{Add, AddAssign}, }; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::network_endian::U32; use zerocopy_derive::*; @@ -21,11 +22,10 @@ use zerocopy_derive::*; PartialEq, Eq, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct Serial(U32); diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index f4d12f9c1..f15132d07 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -4,6 +4,7 @@ use core::{fmt, ops::Range}; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{network_endian::U16, FromBytes, IntoBytes}; use zerocopy_derive::*; @@ -128,11 +129,10 @@ impl<'a> ParseFrom<'a> for EdnsRecord<'a> { Clone, Default, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct EdnsFlags { @@ -207,11 +207,10 @@ pub enum EdnsOption<'b> { PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct OptionCode { @@ -222,7 +221,7 @@ pub struct OptionCode { //----------- UnknownOption -------------------------------------------------- /// Data for an unknown Extended DNS option. -#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive(Debug, IntoBytes, Immutable, ParseBytesByRef)] #[repr(C)] pub struct UnknownOption { /// The unparsed option data. diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 4807784b3..1b5c0baeb 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -10,6 +10,7 @@ use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{ network_endian::{U16, U32}, IntoBytes, @@ -36,11 +37,10 @@ use crate::new_base::{ PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct A { @@ -328,7 +328,7 @@ impl BuildInto for Soa { //----------- Wks ------------------------------------------------------------ /// Well-known services supported on this domain. -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive(IntoBytes, Immutable, ParseBytesByRef)] #[repr(C, packed)] pub struct Wks { /// The address of the host providing these services. diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index 606486d08..fdb2aa674 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -8,6 +8,7 @@ use core::{fmt, str::FromStr}; #[cfg(feature = "std")] use std::net::Ipv6Addr; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::IntoBytes; use zerocopy_derive::*; @@ -27,11 +28,10 @@ use crate::new_base::build::{ PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct Aaaa { diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 3228608cc..afc4820ae 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -2,6 +2,7 @@ use core::ops::Range; +use domain_macros::ParseBytesByRef; use zerocopy_derive::*; use crate::new_base::{ @@ -172,7 +173,7 @@ impl BuildInto for RecordData<'_, N> { //----------- UnknownRecordData ---------------------------------------------- /// Data for an unknown DNS record type. -#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive(Debug, IntoBytes, Immutable, ParseBytesByRef)] #[repr(C)] pub struct UnknownRecordData { /// The unparsed option data. From 499e858929c965fec60b233f068ab2a0a6ee08b8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 15:37:30 +0100 Subject: [PATCH 058/111] Rename '{Split,Parse}From' to '{Split,Parse}Bytes' --- src/new_base/charstr.rs | 15 ++++---- src/new_base/name/label.rs | 14 +++---- src/new_base/name/reversed.rs | 10 ++--- src/new_base/parse/mod.rs | 68 +++++++++++++++++++++++++++------ src/new_base/question.rs | 26 ++++++------- src/new_base/record.rs | 31 +++++++-------- src/new_edns/mod.rs | 34 ++++++++--------- src/new_rdata/basic.rs | 71 ++++++++++++++++++----------------- src/new_rdata/mod.rs | 27 ++++++------- 9 files changed, 172 insertions(+), 124 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index fdd5e5bdf..57f888c27 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -8,7 +8,8 @@ use zerocopy_derive::*; use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, Message, }; @@ -31,7 +32,7 @@ impl<'a> SplitFromMessage<'a> for &'a CharStr { start: usize, ) -> Result<(Self, usize), ParseError> { let bytes = &message.as_bytes()[start..]; - let (this, rest) = Self::split_from(bytes)?; + let (this, rest) = Self::split_bytes(bytes)?; Ok((this, bytes.len() - rest.len())) } } @@ -45,7 +46,7 @@ impl<'a> ParseFromMessage<'a> for &'a CharStr { .as_bytes() .get(range) .ok_or(ParseError) - .and_then(Self::parse_from) + .and_then(Self::parse_bytes) } } @@ -65,8 +66,8 @@ impl BuildIntoMessage for CharStr { //--- Parsing from bytes -impl<'a> SplitFrom<'a> for &'a CharStr { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a> SplitBytes<'a> for &'a CharStr { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (&length, rest) = bytes.split_first().ok_or(ParseError)?; if length as usize > rest.len() { return Err(ParseError); @@ -78,8 +79,8 @@ impl<'a> SplitFrom<'a> for &'a CharStr { } } -impl<'a> ParseFrom<'a> for &'a CharStr { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a> ParseBytes<'a> for &'a CharStr { + fn parse_bytes(bytes: &'a [u8]) -> Result { let (&length, rest) = bytes.split_first().ok_or(ParseError)?; if length as usize != rest.len() { return Err(ParseError); diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 087692047..7068e2e15 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -9,7 +9,7 @@ use core::{ use zerocopy_derive::*; -use crate::new_base::parse::{ParseError, ParseFrom, SplitFrom}; +use crate::new_base::parse::{ParseError, ParseBytes, SplitBytes}; //----------- Label ---------------------------------------------------------- @@ -52,8 +52,8 @@ impl Label { //--- Parsing -impl<'a> SplitFrom<'a> for &'a Label { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a> SplitBytes<'a> for &'a Label { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (&size, rest) = bytes.split_first().ok_or(ParseError)?; if size < 64 && rest.len() >= size as usize { let (label, rest) = bytes.split_at(1 + size as usize); @@ -65,9 +65,9 @@ impl<'a> SplitFrom<'a> for &'a Label { } } -impl<'a> ParseFrom<'a> for &'a Label { - fn parse_from(bytes: &'a [u8]) -> Result { - Self::split_from(bytes).and_then(|(this, rest)| { +impl<'a> ParseBytes<'a> for &'a Label { + fn parse_bytes(bytes: &'a [u8]) -> Result { + Self::split_bytes(bytes).and_then(|(this, rest)| { rest.is_empty().then_some(this).ok_or(ParseError) }) } @@ -254,7 +254,7 @@ impl<'a> Iterator for LabelIter<'a> { // SAFETY: 'bytes' is assumed to only contain valid labels. let (head, tail) = - unsafe { <&Label>::split_from(self.bytes).unwrap_unchecked() }; + unsafe { <&Label>::split_bytes(self.bytes).unwrap_unchecked() }; self.bytes = tail; Some(head) } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 513a72582..ee7b73b9e 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -14,7 +14,7 @@ use zerocopy_derive::*; use crate::new_base::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, }, Message, }; @@ -385,8 +385,8 @@ impl BuildIntoMessage for RevNameBuf { //--- Parsing from bytes -impl<'a> SplitFrom<'a> for RevNameBuf { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a> SplitBytes<'a> for RevNameBuf { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let mut buffer = Self::empty(); let (pointer, rest) = parse_segment(bytes, &mut buffer)?; @@ -401,8 +401,8 @@ impl<'a> SplitFrom<'a> for RevNameBuf { } } -impl<'a> ParseFrom<'a> for RevNameBuf { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a> ParseBytes<'a> for RevNameBuf { + fn parse_bytes(bytes: &'a [u8]) -> Result { let mut buffer = Self::empty(); let (pointer, rest) = parse_segment(bytes, &mut buffer)?; diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index c5faf0440..493542b66 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -72,46 +72,90 @@ impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. -pub trait SplitFrom<'a>: Sized + ParseFrom<'a> { +pub trait SplitBytes<'a>: Sized + ParseBytes<'a> { /// Parse a value of [`Self`] from the start of the byte string. /// /// If parsing is successful, the parsed value and the rest of the string /// are returned. Otherwise, a [`ParseError`] is returned. - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; } /// Parsing from a byte string. -pub trait ParseFrom<'a>: Sized { +pub trait ParseBytes<'a>: Sized { /// Parse a value of [`Self`] from the given byte string. /// /// If parsing is successful, the parsed value is returned. Otherwise, a /// [`ParseError`] is returned. - fn parse_from(bytes: &'a [u8]) -> Result; + fn parse_bytes(bytes: &'a [u8]) -> Result; } -impl<'a, T: ?Sized + SplitBytesByRef> SplitFrom<'a> for &'a T { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a, T: ?Sized + SplitBytesByRef> SplitBytes<'a> for &'a T { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { T::split_bytes_by_ref(bytes).map_err(|_| ParseError) } } -impl<'a, T: ?Sized + ParseBytesByRef> ParseFrom<'a> for &'a T { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a, T: ?Sized + ParseBytesByRef> ParseBytes<'a> for &'a T { + fn parse_bytes(bytes: &'a [u8]) -> Result { T::parse_bytes_by_ref(bytes).map_err(|_| ParseError) } } +impl<'a> SplitBytes<'a> for u8 { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + bytes.split_first().map(|(&f, r)| (f, r)).ok_or(ParseError) + } +} + +impl<'a> ParseBytes<'a> for u8 { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let [result] = bytes else { + return Err(ParseError); + }; + + Ok(*result) + } +} + +impl<'a> SplitBytes<'a> for U16 { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + Self::read_from_prefix(bytes).map_err(Into::into) + } +} + +impl<'a> ParseBytes<'a> for U16 { + fn parse_bytes(bytes: &'a [u8]) -> Result { + Self::read_from_bytes(bytes).map_err(Into::into) + } +} + +impl<'a> SplitBytes<'a> for U32 { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + Self::read_from_prefix(bytes).map_err(Into::into) + } +} + +impl<'a> ParseBytes<'a> for U32 { + fn parse_bytes(bytes: &'a [u8]) -> Result { + Self::read_from_bytes(bytes).map_err(Into::into) + } +} + /// Zero-copy parsing from the start of a byte string. /// +/// This is an extension of [`ParseBytesByRef`] for types which can determine +/// their own length when parsing. It is usually implemented by [`Sized`] +/// types (where the length is just the size of the type), although it can be +/// sometimes implemented by unsized types. +/// /// # Safety /// /// Every implementation of [`SplitBytesByRef`] must satisfy the invariants /// documented on [`split_bytes_by_ref()`]. An incorrect implementation is /// considered to cause undefined behaviour. /// -/// Implementing types should almost always be unaligned, but foregoing this -/// will not cause undefined behaviour (however, it will be very confusing for -/// users). +/// Note that [`ParseBytesByRef`], required by this trait, also has several +/// invariants that need to be considered with care. pub unsafe trait SplitBytesByRef: ParseBytesByRef { /// Interpret a byte string as an instance of [`Self`]. /// @@ -221,7 +265,7 @@ unsafe impl ParseBytesByRef for u8 { return Err(ParseError); }; - return Ok(result); + Ok(result) } fn ptr_with_address(&self, addr: *const ()) -> *const Self { diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 81411aaf0..029f2839f 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -10,7 +10,7 @@ use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, }, Message, }; @@ -98,26 +98,26 @@ where //--- Parsing from bytes -impl<'a, N> SplitFrom<'a> for Question +impl<'a, N> SplitBytes<'a> for Question where - N: SplitFrom<'a>, + N: SplitBytes<'a>, { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (qname, rest) = N::split_from(bytes)?; - let (&qtype, rest) = <&QType>::split_from(rest)?; - let (&qclass, rest) = <&QClass>::split_from(rest)?; + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (qname, rest) = N::split_bytes(bytes)?; + let (&qtype, rest) = <&QType>::split_bytes(rest)?; + let (&qclass, rest) = <&QClass>::split_bytes(rest)?; Ok((Self::new(qname, qtype, qclass), rest)) } } -impl<'a, N> ParseFrom<'a> for Question +impl<'a, N> ParseBytes<'a> for Question where - N: SplitFrom<'a>, + N: SplitBytes<'a>, { - fn parse_from(bytes: &'a [u8]) -> Result { - let (qname, rest) = N::split_from(bytes)?; - let (&qtype, rest) = <&QType>::split_from(rest)?; - let &qclass = <&QClass>::parse_from(rest)?; + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (qname, rest) = N::split_bytes(bytes)?; + let (&qtype, rest) = <&QType>::split_bytes(rest)?; + let &qclass = <&QClass>::parse_bytes(rest)?; Ok(Self::new(qname, qtype, qclass)) } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 36d4c58dd..0b3bab85b 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -16,7 +16,8 @@ use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, Message, }; @@ -150,16 +151,16 @@ where //--- Parsing from bytes -impl<'a, N, D> SplitFrom<'a> for Record +impl<'a, N, D> SplitBytes<'a> for Record where - N: SplitFrom<'a>, + N: SplitBytes<'a>, D: ParseRecordData<'a>, { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (rname, rest) = N::split_from(bytes)?; - let (&rtype, rest) = <&RType>::split_from(rest)?; - let (&rclass, rest) = <&RClass>::split_from(rest)?; - let (&ttl, rest) = <&TTL>::split_from(rest)?; + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (rname, rest) = N::split_bytes(bytes)?; + let (&rtype, rest) = <&RType>::split_bytes(rest)?; + let (&rclass, rest) = <&RClass>::split_bytes(rest)?; + let (&ttl, rest) = <&TTL>::split_bytes(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; @@ -169,16 +170,16 @@ where } } -impl<'a, N, D> ParseFrom<'a> for Record +impl<'a, N, D> ParseBytes<'a> for Record where - N: SplitFrom<'a>, + N: SplitBytes<'a>, D: ParseRecordData<'a>, { - fn parse_from(bytes: &'a [u8]) -> Result { - let (rname, rest) = N::split_from(bytes)?; - let (&rtype, rest) = <&RType>::split_from(rest)?; - let (&rclass, rest) = <&RClass>::split_from(rest)?; - let (&ttl, rest) = <&TTL>::split_from(rest)?; + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (rname, rest) = N::split_bytes(bytes)?; + let (&rtype, rest) = <&RType>::split_bytes(rest)?; + let (&rclass, rest) = <&RClass>::split_bytes(rest)?; + let (&ttl, rest) = <&TTL>::split_bytes(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index f15132d07..8f9c7de65 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -11,7 +11,7 @@ use zerocopy_derive::*; use crate::{ new_base::{ parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, + ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, }, Message, @@ -48,7 +48,7 @@ impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { start: usize, ) -> Result<(Self, usize), ParseError> { let bytes = message.as_bytes().get(start..).ok_or(ParseError)?; - let (this, rest) = Self::split_from(bytes)?; + let (this, rest) = Self::split_bytes(bytes)?; Ok((this, message.as_bytes().len() - rest.len())) } } @@ -62,24 +62,24 @@ impl<'a> ParseFromMessage<'a> for EdnsRecord<'a> { .as_bytes() .get(range) .ok_or(ParseError) - .and_then(Self::parse_from) + .and_then(Self::parse_bytes) } } //--- Parsing from bytes -impl<'a> SplitFrom<'a> for EdnsRecord<'a> { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a> SplitBytes<'a> for EdnsRecord<'a> { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { // Strip the record name (root) and the record type. let rest = bytes.strip_prefix(&[0, 0, 41]).ok_or(ParseError)?; - let (&max_udp_payload, rest) = <&U16>::split_from(rest)?; - let (&ext_rcode, rest) = <&u8>::split_from(rest)?; - let (&version, rest) = <&u8>::split_from(rest)?; - let (&flags, rest) = <&EdnsFlags>::split_from(rest)?; + let (&max_udp_payload, rest) = <&U16>::split_bytes(rest)?; + let (&ext_rcode, rest) = <&u8>::split_bytes(rest)?; + let (&version, rest) = <&u8>::split_bytes(rest)?; + let (&flags, rest) = <&EdnsFlags>::split_bytes(rest)?; // Split the record size and data. - let (&size, rest) = <&U16>::split_from(rest)?; + let (&size, rest) = <&U16>::split_bytes(rest)?; let size: usize = size.get().into(); let (options, rest) = Opt::ref_from_prefix_with_elems(rest, size)?; @@ -96,18 +96,18 @@ impl<'a> SplitFrom<'a> for EdnsRecord<'a> { } } -impl<'a> ParseFrom<'a> for EdnsRecord<'a> { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a> ParseBytes<'a> for EdnsRecord<'a> { + fn parse_bytes(bytes: &'a [u8]) -> Result { // Strip the record name (root) and the record type. let rest = bytes.strip_prefix(&[0, 0, 41]).ok_or(ParseError)?; - let (&max_udp_payload, rest) = <&U16>::split_from(rest)?; - let (&ext_rcode, rest) = <&u8>::split_from(rest)?; - let (&version, rest) = <&u8>::split_from(rest)?; - let (&flags, rest) = <&EdnsFlags>::split_from(rest)?; + let (&max_udp_payload, rest) = <&U16>::split_bytes(rest)?; + let (&ext_rcode, rest) = <&u8>::split_bytes(rest)?; + let (&version, rest) = <&u8>::split_bytes(rest)?; + let (&flags, rest) = <&EdnsFlags>::split_bytes(rest)?; // Split the record size and data. - let (&size, rest) = <&U16>::split_from(rest)?; + let (&size, rest) = <&U16>::split_bytes(rest)?; let size: usize = size.get().into(); let options = Opt::ref_from_bytes_with_elems(rest, size)?; diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 1b5c0baeb..bfb11b9de 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -20,7 +20,8 @@ use zerocopy_derive::*; use crate::new_base::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, CharStr, Message, Serial, }; @@ -142,9 +143,9 @@ impl BuildIntoMessage for Ns { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { - fn parse_from(bytes: &'a [u8]) -> Result { - N::parse_from(bytes).map(|name| Self { name }) +impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Ns { + fn parse_bytes(bytes: &'a [u8]) -> Result { + N::parse_bytes(bytes).map(|name| Self { name }) } } @@ -193,9 +194,9 @@ impl BuildIntoMessage for CName { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for CName { - fn parse_from(bytes: &'a [u8]) -> Result { - N::parse_from(bytes).map(|name| Self { name }) +impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for CName { + fn parse_bytes(bytes: &'a [u8]) -> Result { + N::parse_bytes(bytes).map(|name| Self { name }) } } @@ -285,15 +286,15 @@ impl BuildIntoMessage for Soa { //--- Parsing from bytes -impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { - fn parse_from(bytes: &'a [u8]) -> Result { - let (mname, rest) = N::split_from(bytes)?; - let (rname, rest) = N::split_from(rest)?; - let (&serial, rest) = <&Serial>::split_from(rest)?; - let (&refresh, rest) = <&U32>::split_from(rest)?; - let (&retry, rest) = <&U32>::split_from(rest)?; - let (&expire, rest) = <&U32>::split_from(rest)?; - let &minimum = <&U32>::parse_from(rest)?; +impl<'a, N: SplitBytes<'a>> ParseBytes<'a> for Soa { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (mname, rest) = N::split_bytes(bytes)?; + let (rname, rest) = N::split_bytes(rest)?; + let (&serial, rest) = <&Serial>::split_bytes(rest)?; + let (&refresh, rest) = <&U32>::split_bytes(rest)?; + let (&retry, rest) = <&U32>::split_bytes(rest)?; + let (&expire, rest) = <&U32>::split_bytes(rest)?; + let &minimum = <&U32>::parse_bytes(rest)?; Ok(Self { mname, @@ -425,9 +426,9 @@ impl BuildIntoMessage for Ptr { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { - fn parse_from(bytes: &'a [u8]) -> Result { - N::parse_from(bytes).map(|name| Self { name }) +impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Ptr { + fn parse_bytes(bytes: &'a [u8]) -> Result { + N::parse_bytes(bytes).map(|name| Self { name }) } } @@ -465,7 +466,7 @@ impl<'a> ParseFromMessage<'a> for HInfo<'a> { .as_bytes() .get(range) .ok_or(ParseError) - .and_then(Self::parse_from) + .and_then(Self::parse_bytes) } } @@ -485,10 +486,10 @@ impl BuildIntoMessage for HInfo<'_> { //--- Parsing from bytes -impl<'a> ParseFrom<'a> for HInfo<'a> { - fn parse_from(bytes: &'a [u8]) -> Result { - let (cpu, rest) = <&CharStr>::split_from(bytes)?; - let os = <&CharStr>::parse_from(rest)?; +impl<'a> ParseBytes<'a> for HInfo<'a> { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (cpu, rest) = <&CharStr>::split_bytes(bytes)?; + let os = <&CharStr>::parse_bytes(rest)?; Ok(Self { cpu, os }) } } @@ -552,10 +553,10 @@ impl BuildIntoMessage for Mx { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Mx { - fn parse_from(bytes: &'a [u8]) -> Result { - let (&preference, rest) = <&U16>::split_from(bytes)?; - let exchange = N::parse_from(rest)?; +impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Mx { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (&preference, rest) = <&U16>::split_bytes(bytes)?; + let exchange = N::parse_bytes(rest)?; Ok(Self { preference, exchange, @@ -594,11 +595,11 @@ impl Txt { &self, ) -> impl Iterator> + '_ { // NOTE: A TXT record always has at least one 'CharStr' within. - let first = <&CharStr>::split_from(&self.content); + let first = <&CharStr>::split_bytes(&self.content); core::iter::successors(Some(first), |prev| { prev.as_ref() .ok() - .map(|(_elem, rest)| <&CharStr>::split_from(rest)) + .map(|(_elem, rest)| <&CharStr>::split_bytes(rest)) }) .map(|result| result.map(|(elem, _rest)| elem)) } @@ -615,7 +616,7 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { .as_bytes() .get(range) .ok_or(ParseError) - .and_then(Self::parse_from) + .and_then(Self::parse_bytes) } } @@ -632,12 +633,12 @@ impl BuildIntoMessage for Txt { //--- Parsing from bytes -impl<'a> ParseFrom<'a> for &'a Txt { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a> ParseBytes<'a> for &'a Txt { + fn parse_bytes(bytes: &'a [u8]) -> Result { // NOTE: The input must contain at least one 'CharStr'. - let (_, mut rest) = <&CharStr>::split_from(bytes)?; + let (_, mut rest) = <&CharStr>::split_bytes(bytes)?; while !rest.is_empty() { - (_, rest) = <&CharStr>::split_from(rest)?; + (_, rest) = <&CharStr>::split_bytes(rest)?; } // SAFETY: 'Txt' is 'repr(transparent)' to '[u8]'. diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index afc4820ae..0cf020988 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -8,7 +8,8 @@ use zerocopy_derive::*; use crate::new_base::{ build::{BuildInto, BuildIntoMessage, Builder, TruncationError}, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, Message, ParseRecordData, RType, }; @@ -68,7 +69,7 @@ pub enum RecordData<'a, N> { impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> where - N: SplitFrom<'a> + SplitFromMessage<'a>, + N: SplitBytes<'a> + SplitFromMessage<'a>, { fn parse_record_data( message: &'a Message, @@ -110,17 +111,17 @@ where rtype: RType, ) -> Result { match rtype { - RType::A => <&A>::parse_from(bytes).map(Self::A), - RType::NS => Ns::parse_from(bytes).map(Self::Ns), - RType::CNAME => CName::parse_from(bytes).map(Self::CName), - RType::SOA => Soa::parse_from(bytes).map(Self::Soa), - RType::WKS => <&Wks>::parse_from(bytes).map(Self::Wks), - RType::PTR => Ptr::parse_from(bytes).map(Self::Ptr), - RType::HINFO => HInfo::parse_from(bytes).map(Self::HInfo), - RType::MX => Mx::parse_from(bytes).map(Self::Mx), - RType::TXT => <&Txt>::parse_from(bytes).map(Self::Txt), - RType::AAAA => <&Aaaa>::parse_from(bytes).map(Self::Aaaa), - _ => <&UnknownRecordData>::parse_from(bytes) + RType::A => <&A>::parse_bytes(bytes).map(Self::A), + RType::NS => Ns::parse_bytes(bytes).map(Self::Ns), + RType::CNAME => CName::parse_bytes(bytes).map(Self::CName), + RType::SOA => Soa::parse_bytes(bytes).map(Self::Soa), + RType::WKS => <&Wks>::parse_bytes(bytes).map(Self::Wks), + RType::PTR => Ptr::parse_bytes(bytes).map(Self::Ptr), + RType::HINFO => HInfo::parse_bytes(bytes).map(Self::HInfo), + RType::MX => Mx::parse_bytes(bytes).map(Self::Mx), + RType::TXT => <&Txt>::parse_bytes(bytes).map(Self::Txt), + RType::AAAA => <&Aaaa>::parse_bytes(bytes).map(Self::Aaaa), + _ => <&UnknownRecordData>::parse_bytes(bytes) .map(|data| Self::Unknown(rtype, data)), } } From 2d60092169837f60ea32c1e0f12355389433a1b6 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 15:37:51 +0100 Subject: [PATCH 059/111] [macros] Add derives for '{Split,Parse}Bytes' --- macros/src/impls.rs | 77 ++++++------ macros/src/lib.rs | 282 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 315 insertions(+), 44 deletions(-) diff --git a/macros/src/impls.rs b/macros/src/impls.rs index c5af737fb..e72ef91ec 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -143,7 +143,16 @@ impl ImplSkeleton { /// If the type is concrete, a verifying statement is added for it. /// Otherwise, it is added to the where clause. pub fn require_bound(&mut self, target: Type, bound: TypeParamBound) { - if self.is_concrete(&target) { + let mut visitor = ConcretenessVisitor { + skeleton: self, + is_concrete: true, + }; + + // Concreteness applies to both the type and the bound. + visitor.visit_type(&target); + visitor.visit_type_param_bound(&bound); + + if visitor.is_concrete { // Add a concrete requirement for this bound. self.requirements.stmts.push(parse_quote! { const _: fn() = || { @@ -154,54 +163,16 @@ impl ImplSkeleton { } else { // Add this bound to the `where` clause. let mut bounds = Punctuated::new(); - bounds.push_value(bound); + bounds.push(bound); let pred = WherePredicate::Type(PredicateType { lifetimes: None, bounded_ty: target, colon_token: Default::default(), bounds, }); - self.where_clause.predicates.push_value(pred); + self.where_clause.predicates.push(pred); } } - - /// Whether a type is concrete within this `impl` block. - pub fn is_concrete(&self, target: &Type) -> bool { - struct ConcretenessVisitor<'a> { - /// The `impl` skeleton being added to. - skeleton: &'a ImplSkeleton, - - /// Whether the visited type is concrete. - is_concrete: bool, - } - - impl<'ast> Visit<'ast> for ConcretenessVisitor<'_> { - fn visit_lifetime(&mut self, i: &'ast Lifetime) { - self.is_concrete = self.is_concrete - && self - .skeleton - .lifetimes - .iter() - .all(|l| l.lifetime != *i); - } - - fn visit_ident(&mut self, i: &'ast Ident) { - self.is_concrete = self.is_concrete - && self.skeleton.types.iter().all(|t| t.ident != *i); - self.is_concrete = self.is_concrete - && self.skeleton.consts.iter().all(|c| c.ident != *i); - } - } - - let mut visitor = ConcretenessVisitor { - skeleton: self, - is_concrete: true, - }; - - visitor.visit_type(target); - - visitor.is_concrete - } } impl ToTokens for ImplSkeleton { @@ -235,3 +206,27 @@ impl ToTokens for ImplSkeleton { } } } + +//----------- ConcretenessVisitor -------------------------------------------- + +struct ConcretenessVisitor<'a> { + /// The `impl` skeleton being added to. + skeleton: &'a ImplSkeleton, + + /// Whether the visited type is concrete. + is_concrete: bool, +} + +impl<'ast> Visit<'ast> for ConcretenessVisitor<'_> { + fn visit_lifetime(&mut self, i: &'ast Lifetime) { + self.is_concrete = self.is_concrete + && self.skeleton.lifetimes.iter().all(|l| l.lifetime != *i); + } + + fn visit_ident(&mut self, i: &'ast Ident) { + self.is_concrete = self.is_concrete + && self.skeleton.types.iter().all(|t| t.ident != *i); + self.is_concrete = self.is_concrete + && self.skeleton.consts.iter().all(|c| c.ident != *i); + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8cb26183f..33eb6eef5 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -3,8 +3,8 @@ //! [`domain`]: https://docs.rs/domain use proc_macro as pm; -use proc_macro2::TokenStream; -use quote::{quote, ToTokens}; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, ToTokens}; use spanned::Spanned; use syn::*; @@ -14,6 +14,275 @@ use impls::ImplSkeleton; mod repr; use repr::Repr; +//----------- SplitBytes ----------------------------------------------------- + +#[proc_macro_derive(SplitBytes)] +pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'SplitBytes' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'SplitBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let bound = + parse_quote!(::domain::new_base::parse::SplitBytes<'bytes>); + let mut skeleton = ImplSkeleton::new(&input, false, bound); + + // Pick a non-conflicting name for the parsing lifetime. + let lifetime = [format_ident!("bytes")] + .into_iter() + .chain((0u32..).map(|i| format_ident!("bytes_{}", i))) + .find(|id| { + skeleton.lifetimes.iter().all(|l| l.lifetime.ident != *id) + }) + .map(|ident| Lifetime { + apostrophe: Span::call_site(), + ident, + }) + .unwrap(); + + // Add the parsing lifetime to the 'impl'. + if skeleton.lifetimes.len() > 0 { + let lifetimes = skeleton.lifetimes.iter(); + let param = parse_quote! { + #lifetime: #(#lifetimes)+* + }; + skeleton.lifetimes.push(param); + } else { + skeleton.lifetimes.push(parse_quote! { #lifetime }) + } + + // Establish bounds on the fields. + for field in data.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + ); + } + + // Construct a 'Self' expression. + let self_expr = match &data.fields { + Fields::Named(_) => { + let names = data.fields.members(); + let exprs = + names.clone().map(|n| format_ident!("field_{}", n)); + quote! { + Self { + #(#names: #exprs,)* + } + } + } + + Fields::Unnamed(_) => { + let exprs = data + .fields + .members() + .map(|n| format_ident!("field_{}", n)); + quote! { + Self(#(#exprs,)*) + } + } + + Fields::Unit => quote! { Self }, + }; + + // Define 'parse_bytes()'. + let names = + data.fields.members().map(|n| format_ident!("field_{}", n)); + let tys = data.fields.iter().map(|f| &f.ty); + skeleton.contents.stmts.push(parse_quote! { + fn split_bytes( + bytes: & #lifetime [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (Self, & #lifetime [::domain::__core::primitive::u8]), + ::domain::new_base::parse::ParseError, + > { + #(let (#names, bytes) = + <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> + ::split_bytes(bytes)?;)* + Ok((#self_expr, bytes)) + } + }); + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- ParseBytes ----------------------------------------------------- + +#[proc_macro_derive(ParseBytes)] +pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'ParseBytes' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'ParseBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + // Split up the last field from the rest. + let mut fields = data.fields.iter(); + let Some(last) = fields.next_back() else { + // This type has no fields. Return a simple implementation. + assert!(input.generics.params.is_empty()); + let where_clause = input.generics.where_clause; + let name = input.ident; + + // This will tokenize to '{}', '()', or ''. + let fields = data.fields.to_token_stream(); + + return Ok(quote! { + impl <'bytes> + ::domain::new_base::parse::ParseBytes<'bytes> + for #name + #where_clause { + fn parse_bytes( + bytes: &'bytes [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + Self, + ::domain::new_base::parse::ParseError, + > { + if bytes.is_empty() { + Ok(Self #fields) + } else { + Err() + } + } + } + }); + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let bound = + parse_quote!(::domain::new_base::parse::ParseBytes<'bytes>); + let mut skeleton = ImplSkeleton::new(&input, false, bound); + + // Pick a non-conflicting name for the parsing lifetime. + let lifetime = [format_ident!("bytes")] + .into_iter() + .chain((0u32..).map(|i| format_ident!("bytes_{}", i))) + .find(|id| { + skeleton.lifetimes.iter().all(|l| l.lifetime.ident != *id) + }) + .map(|ident| Lifetime { + apostrophe: Span::call_site(), + ident, + }) + .unwrap(); + + // Add the parsing lifetime to the 'impl'. + if skeleton.lifetimes.len() > 0 { + let lifetimes = skeleton.lifetimes.iter(); + let param = parse_quote! { + #lifetime: #(#lifetimes)+* + }; + skeleton.lifetimes.push(param); + } else { + skeleton.lifetimes.push(parse_quote! { #lifetime }) + } + + // Establish bounds on the fields. + for field in fields.clone() { + // This field should implement 'SplitBytes'. + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + ); + } + // The last field should implement 'ParseBytes'. + skeleton.require_bound( + last.ty.clone(), + parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + ); + + // Construct a 'Self' expression. + let self_expr = match &data.fields { + Fields::Named(_) => { + let names = data.fields.members(); + let exprs = + names.clone().map(|n| format_ident!("field_{}", n)); + quote! { + Self { + #(#names: #exprs,)* + } + } + } + + Fields::Unnamed(_) => { + let exprs = data + .fields + .members() + .map(|n| format_ident!("field_{}", n)); + quote! { + Self(#(#exprs,)*) + } + } + + Fields::Unit => unreachable!(), + }; + + // Define 'parse_bytes()'. + let names = data + .fields + .members() + .take(fields.len()) + .map(|n| format_ident!("field_{}", n)); + let tys = fields.clone().map(|f| &f.ty); + let last_ty = &last.ty; + let last_name = + format_ident!("field_{}", data.fields.members().last().unwrap()); + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes( + bytes: & #lifetime [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + Self, + ::domain::new_base::parse::ParseError, + > { + #(let (#names, bytes) = + <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> + ::split_bytes(bytes)?;)* + let #last_name = + <#last_ty as ::domain::new_base::parse::ParseBytes<#lifetime>> + ::parse_bytes(bytes)?; + Ok(#self_expr) + } + }); + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + //----------- SplitBytesByRef ------------------------------------------------ #[proc_macro_derive(SplitBytesByRef)] @@ -163,7 +432,14 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { &Self, ::domain::new_base::parse::ParseError, > { - Ok(unsafe { &*bytes.as_ptr().cast::() }) + if bytes.is_empty() { + // SAFETY: 'Self' is a 'struct' with no fields, + // and so has size 0 and alignment 1. It can be + // constructed at any address. + Ok(unsafe { &*bytes.as_ptr().cast::() }) + } else { + Err(::domain::new_base::parse::ParseError) + } } fn ptr_with_address( From 7e0ef89fefca1a9fd98a3a5bb286cc34f7c228d2 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 15:41:50 +0100 Subject: [PATCH 060/111] [macros] Factor out 'new_lifetime()' --- macros/src/impls.rs | 17 +++++++++++++++-- macros/src/lib.rs | 30 +++--------------------------- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/macros/src/impls.rs b/macros/src/impls.rs index e72ef91ec..0d97309e6 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -1,7 +1,7 @@ //! Helpers for generating `impl` blocks. -use proc_macro2::TokenStream; -use quote::{quote, ToTokens}; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, ToTokens}; use syn::{punctuated::Punctuated, visit::Visit, *}; //----------- ImplSkeleton --------------------------------------------------- @@ -173,6 +173,19 @@ impl ImplSkeleton { self.where_clause.predicates.push(pred); } } + + /// Generate a unique lifetime with the given prefix. + pub fn new_lifetime(&self, prefix: &str) -> Lifetime { + [format_ident!("{}", prefix)] + .into_iter() + .chain((0u32..).map(|i| format_ident!("{}_{}", prefix, i))) + .find(|id| self.lifetimes.iter().all(|l| l.lifetime.ident != *id)) + .map(|ident| Lifetime { + apostrophe: Span::call_site(), + ident, + }) + .unwrap() + } } impl ToTokens for ImplSkeleton { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 33eb6eef5..046b439ca 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -3,7 +3,7 @@ //! [`domain`]: https://docs.rs/domain use proc_macro as pm; -use proc_macro2::{Span, TokenStream}; +use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; use spanned::Spanned; use syn::*; @@ -40,20 +40,8 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { parse_quote!(::domain::new_base::parse::SplitBytes<'bytes>); let mut skeleton = ImplSkeleton::new(&input, false, bound); - // Pick a non-conflicting name for the parsing lifetime. - let lifetime = [format_ident!("bytes")] - .into_iter() - .chain((0u32..).map(|i| format_ident!("bytes_{}", i))) - .find(|id| { - skeleton.lifetimes.iter().all(|l| l.lifetime.ident != *id) - }) - .map(|ident| Lifetime { - apostrophe: Span::call_site(), - ident, - }) - .unwrap(); - // Add the parsing lifetime to the 'impl'. + let lifetime = skeleton.new_lifetime("bytes"); if skeleton.lifetimes.len() > 0 { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { @@ -183,20 +171,8 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { parse_quote!(::domain::new_base::parse::ParseBytes<'bytes>); let mut skeleton = ImplSkeleton::new(&input, false, bound); - // Pick a non-conflicting name for the parsing lifetime. - let lifetime = [format_ident!("bytes")] - .into_iter() - .chain((0u32..).map(|i| format_ident!("bytes_{}", i))) - .find(|id| { - skeleton.lifetimes.iter().all(|l| l.lifetime.ident != *id) - }) - .map(|ident| Lifetime { - apostrophe: Span::call_site(), - ident, - }) - .unwrap(); - // Add the parsing lifetime to the 'impl'. + let lifetime = skeleton.new_lifetime("bytes"); if skeleton.lifetimes.len() > 0 { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { From b0f6b679e034ebcdd1d02ce8c316af3d63ee104b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 16:06:29 +0100 Subject: [PATCH 061/111] Use parsing trait derives across the 'new_*' codebase --- src/new_base/message.rs | 18 +++++- src/new_base/question.rs | 39 +++--------- src/new_base/record.rs | 30 +++++---- src/new_base/serial.rs | 5 +- src/new_edns/mod.rs | 24 +++++-- src/new_rdata/basic.rs | 132 ++++++++++++++++----------------------- src/new_rdata/edns.rs | 8 +-- src/new_rdata/ipv6.rs | 5 +- 8 files changed, 127 insertions(+), 134 deletions(-) diff --git a/src/new_base/message.rs b/src/new_base/message.rs index e60ae76ff..5d635a9e3 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -5,10 +5,14 @@ use core::fmt; use zerocopy::network_endian::U16; use zerocopy_derive::*; +use domain_macros::*; + //----------- Message -------------------------------------------------------- /// A DNS message. -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive( + FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, ParseBytesByRef, +)] #[repr(C, packed)] pub struct Message { /// The message header. @@ -31,6 +35,10 @@ pub struct Message { KnownLayout, Immutable, Unaligned, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, )] #[repr(C)] pub struct Header { @@ -71,6 +79,10 @@ impl fmt::Display for Header { KnownLayout, Immutable, Unaligned, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, )] #[repr(transparent)] pub struct HeaderFlags { @@ -232,6 +244,10 @@ impl fmt::Display for HeaderFlags { KnownLayout, Immutable, Unaligned, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, )] #[repr(C)] pub struct SectionCounts { diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 029f2839f..8a9ad771f 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,23 +2,22 @@ use core::ops::Range; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{network_endian::U16, IntoBytes}; use zerocopy_derive::*; +use domain_macros::*; + use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, - parse::{ - ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, - }, + parse::{ParseError, ParseFromMessage, SplitFromMessage}, Message, }; //----------- Question ------------------------------------------------------- /// A DNS question. -#[derive(Clone)] +#[derive(Clone, ParseBytes, SplitBytes)] pub struct Question { /// The domain name being requested. pub qname: N, @@ -96,32 +95,6 @@ where } } -//--- Parsing from bytes - -impl<'a, N> SplitBytes<'a> for Question -where - N: SplitBytes<'a>, -{ - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (qname, rest) = N::split_bytes(bytes)?; - let (&qtype, rest) = <&QType>::split_bytes(rest)?; - let (&qclass, rest) = <&QClass>::split_bytes(rest)?; - Ok((Self::new(qname, qtype, qclass), rest)) - } -} - -impl<'a, N> ParseBytes<'a> for Question -where - N: SplitBytes<'a>, -{ - fn parse_bytes(bytes: &'a [u8]) -> Result { - let (qname, rest) = N::split_bytes(bytes)?; - let (&qtype, rest) = <&QType>::split_bytes(rest)?; - let &qclass = <&QClass>::parse_bytes(rest)?; - Ok(Self::new(qname, qtype, qclass)) - } -} - //--- Building into byte strings impl BuildInto for Question @@ -153,7 +126,9 @@ where Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -176,7 +151,9 @@ pub struct QType { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 0b3bab85b..c02780d53 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -5,19 +5,20 @@ use core::{ ops::{Deref, Range}, }; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{ network_endian::{U16, U32}, - FromBytes, IntoBytes, SizeError, + FromBytes, IntoBytes, }; use zerocopy_derive::*; +use domain_macros::*; + use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, + ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, + SplitBytes, SplitFromMessage, }, Message, }; @@ -104,8 +105,7 @@ where range: Range, ) -> Result { let message = &message.as_bytes()[..range.end]; - let message = Message::ref_from_bytes(message) - .map_err(SizeError::from) + let message = Message::parse_bytes_by_ref(message) .expect("The input range ends past the message header"); let (this, rest) = Self::split_from_message(message, range.start)?; @@ -158,9 +158,9 @@ where { fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (rname, rest) = N::split_bytes(bytes)?; - let (&rtype, rest) = <&RType>::split_bytes(rest)?; - let (&rclass, rest) = <&RClass>::split_bytes(rest)?; - let (&ttl, rest) = <&TTL>::split_bytes(rest)?; + let (rtype, rest) = RType::split_bytes(rest)?; + let (rclass, rest) = RClass::split_bytes(rest)?; + let (ttl, rest) = TTL::split_bytes(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; @@ -177,9 +177,9 @@ where { fn parse_bytes(bytes: &'a [u8]) -> Result { let (rname, rest) = N::split_bytes(bytes)?; - let (&rtype, rest) = <&RType>::split_bytes(rest)?; - let (&rclass, rest) = <&RClass>::split_bytes(rest)?; - let (&ttl, rest) = <&TTL>::split_bytes(rest)?; + let (rtype, rest) = RType::split_bytes(rest)?; + let (rclass, rest) = RClass::split_bytes(rest)?; + let (ttl, rest) = TTL::split_bytes(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; @@ -232,7 +232,9 @@ where Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -289,7 +291,9 @@ impl RType { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -312,7 +316,9 @@ pub struct RClass { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index f351e1a46..2fe5e8f7c 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -8,10 +8,11 @@ use core::{ ops::{Add, AddAssign}, }; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::network_endian::U32; use zerocopy_derive::*; +use domain_macros::*; + //----------- Serial --------------------------------------------------------- /// A serial number. @@ -24,7 +25,9 @@ use zerocopy_derive::*; Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 8f9c7de65..d5a1d366f 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -4,15 +4,16 @@ use core::{fmt, ops::Range}; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; -use zerocopy::{network_endian::U16, FromBytes, IntoBytes}; +use zerocopy::{network_endian::U16, IntoBytes}; use zerocopy_derive::*; +use domain_macros::*; + use crate::{ new_base::{ parse::{ - ParseError, ParseBytes, ParseFromMessage, SplitBytes, - SplitFromMessage, + ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, + SplitBytes, SplitFromMessage, }, Message, }, @@ -81,7 +82,11 @@ impl<'a> SplitBytes<'a> for EdnsRecord<'a> { // Split the record size and data. let (&size, rest) = <&U16>::split_bytes(rest)?; let size: usize = size.get().into(); - let (options, rest) = Opt::ref_from_prefix_with_elems(rest, size)?; + if rest.len() < size { + return Err(ParseError); + } + let (options, rest) = rest.split_at(size); + let options = Opt::parse_bytes_by_ref(options)?; Ok(( Self { @@ -109,7 +114,10 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { // Split the record size and data. let (&size, rest) = <&U16>::split_bytes(rest)?; let size: usize = size.get().into(); - let options = Opt::ref_from_bytes_with_elems(rest, size)?; + if rest.len() != size { + return Err(ParseError); + } + let options = Opt::parse_bytes_by_ref(rest)?; Ok(Self { max_udp_payload, @@ -131,7 +139,9 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -209,7 +219,9 @@ pub enum EdnsOption<'b> { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index bfb11b9de..53251ee7f 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -10,13 +10,14 @@ use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{ network_endian::{U16, U32}, IntoBytes, }; use zerocopy_derive::*; +use domain_macros::*; + use crate::new_base::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ @@ -40,7 +41,9 @@ use crate::new_base::{ Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -112,7 +115,18 @@ impl BuildInto for A { //----------- Ns ------------------------------------------------------------- /// The authoritative name server for this domain. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + ParseBytes, + SplitBytes, +)] #[repr(transparent)] pub struct Ns { /// The name of the authoritative server. @@ -141,14 +155,6 @@ impl BuildIntoMessage for Ns { } } -//--- Parsing from bytes - -impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Ns { - fn parse_bytes(bytes: &'a [u8]) -> Result { - N::parse_bytes(bytes).map(|name| Self { name }) - } -} - //--- Building into bytes impl BuildInto for Ns { @@ -163,7 +169,18 @@ impl BuildInto for Ns { //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + ParseBytes, + SplitBytes, +)] #[repr(transparent)] pub struct CName { /// The canonical name. @@ -192,14 +209,6 @@ impl BuildIntoMessage for CName { } } -//--- Parsing from bytes - -impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for CName { - fn parse_bytes(bytes: &'a [u8]) -> Result { - N::parse_bytes(bytes).map(|name| Self { name }) - } -} - //--- Building into bytes impl BuildInto for CName { @@ -214,7 +223,7 @@ impl BuildInto for CName { //----------- Soa ------------------------------------------------------------ /// The start of a zone of authority. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, ParseBytes, SplitBytes)] pub struct Soa { /// The name server which provided this zone. pub mname: N, @@ -284,30 +293,6 @@ impl BuildIntoMessage for Soa { } } -//--- Parsing from bytes - -impl<'a, N: SplitBytes<'a>> ParseBytes<'a> for Soa { - fn parse_bytes(bytes: &'a [u8]) -> Result { - let (mname, rest) = N::split_bytes(bytes)?; - let (rname, rest) = N::split_bytes(rest)?; - let (&serial, rest) = <&Serial>::split_bytes(rest)?; - let (&refresh, rest) = <&U32>::split_bytes(rest)?; - let (&retry, rest) = <&U32>::split_bytes(rest)?; - let (&expire, rest) = <&U32>::split_bytes(rest)?; - let &minimum = <&U32>::parse_bytes(rest)?; - - Ok(Self { - mname, - rname, - serial, - refresh, - retry, - expire, - minimum, - }) - } -} - //--- Building into byte strings impl BuildInto for Soa { @@ -395,7 +380,18 @@ impl BuildInto for Wks { //----------- Ptr ------------------------------------------------------------ /// A pointer to another domain name. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + ParseBytes, + SplitBytes, +)] #[repr(transparent)] pub struct Ptr { /// The referenced domain name. @@ -424,14 +420,6 @@ impl BuildIntoMessage for Ptr { } } -//--- Parsing from bytes - -impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Ptr { - fn parse_bytes(bytes: &'a [u8]) -> Result { - N::parse_bytes(bytes).map(|name| Self { name }) - } -} - //--- Building into bytes impl BuildInto for Ptr { @@ -446,7 +434,7 @@ impl BuildInto for Ptr { //----------- HInfo ---------------------------------------------------------- /// Information about the host computer. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, ParseBytes, SplitBytes)] pub struct HInfo<'a> { /// The CPU type. pub cpu: &'a CharStr, @@ -484,16 +472,6 @@ impl BuildIntoMessage for HInfo<'_> { } } -//--- Parsing from bytes - -impl<'a> ParseBytes<'a> for HInfo<'a> { - fn parse_bytes(bytes: &'a [u8]) -> Result { - let (cpu, rest) = <&CharStr>::split_bytes(bytes)?; - let os = <&CharStr>::parse_bytes(rest)?; - Ok(Self { cpu, os }) - } -} - //--- Building into bytes impl BuildInto for HInfo<'_> { @@ -510,7 +488,18 @@ impl BuildInto for HInfo<'_> { //----------- Mx ------------------------------------------------------------- /// A host that can exchange mail for this domain. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + ParseBytes, + SplitBytes, +)] #[repr(C)] pub struct Mx { /// The preference for this host over others. @@ -551,19 +540,6 @@ impl BuildIntoMessage for Mx { } } -//--- Parsing from bytes - -impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Mx { - fn parse_bytes(bytes: &'a [u8]) -> Result { - let (&preference, rest) = <&U16>::split_bytes(bytes)?; - let exchange = N::parse_bytes(rest)?; - Ok(Self { - preference, - exchange, - }) - } -} - //--- Building into byte strings impl BuildInto for Mx { diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 89e146062..4f84ba837 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -4,6 +4,8 @@ use zerocopy_derive::*; +use domain_macros::*; + use crate::new_base::build::{ self, BuildInto, BuildIntoMessage, TruncationError, }; @@ -17,13 +19,11 @@ use crate::new_base::build::{ PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, )] -#[repr(C)] // 'derive(KnownLayout)' doesn't work with 'repr(transparent)'. +#[repr(transparent)] pub struct Opt { /// The raw serialized options. contents: [u8], diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index fdb2aa674..77df07cc5 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -8,10 +8,11 @@ use core::{fmt, str::FromStr}; #[cfg(feature = "std")] use std::net::Ipv6Addr; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::IntoBytes; use zerocopy_derive::*; +use domain_macros::*; + use crate::new_base::build::{ self, BuildInto, BuildIntoMessage, TruncationError, }; @@ -30,7 +31,9 @@ use crate::new_base::build::{ Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] From 4a5d343e7f38b9380ea18f40734fe8ee05b57f9e Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 17:54:07 +0100 Subject: [PATCH 062/111] Impl and use derives for 'AsBytes' and 'BuildBytes' --- macros/src/impls.rs | 13 ++- macros/src/lib.rs | 134 +++++++++++++++++++++++++++--- src/new_base/build/builder.rs | 4 +- src/new_base/build/mod.rs | 116 ++++++++++++++++++++++---- src/new_base/charstr.rs | 10 +-- src/new_base/message.rs | 8 +- src/new_base/name/label.rs | 6 +- src/new_base/name/reversed.rs | 19 +++-- src/new_base/question.rs | 32 ++------ src/new_base/record.rs | 46 ++++------- src/new_base/serial.rs | 5 +- src/new_rdata/basic.rs | 148 ++++++---------------------------- src/new_rdata/edns.rs | 26 +----- src/new_rdata/ipv6.rs | 20 +---- src/new_rdata/mod.rs | 33 ++++---- 15 files changed, 332 insertions(+), 288 deletions(-) diff --git a/macros/src/impls.rs b/macros/src/impls.rs index 0d97309e6..5e3b884a0 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -21,7 +21,7 @@ pub struct ImplSkeleton { pub unsafety: Option, /// The trait being implemented. - pub bound: Path, + pub bound: Option, /// The type being implemented on. pub subject: Path, @@ -38,7 +38,7 @@ pub struct ImplSkeleton { impl ImplSkeleton { /// Construct an [`ImplSkeleton`] for a [`DeriveInput`]. - pub fn new(input: &DeriveInput, unsafety: bool, bound: Path) -> Self { + pub fn new(input: &DeriveInput, unsafety: bool) -> Self { let mut lifetimes = Vec::new(); let mut types = Vec::new(); let mut consts = Vec::new(); @@ -130,7 +130,7 @@ impl ImplSkeleton { types, consts, unsafety, - bound, + bound: None, subject, where_clause, contents, @@ -202,10 +202,15 @@ impl ToTokens for ImplSkeleton { requirements, } = self; + let target = match bound { + Some(bound) => quote!(#bound for #subject), + None => quote!(#subject), + }; + quote! { #unsafety impl<#(#lifetimes,)* #(#types,)* #(#consts,)*> - #bound for #subject + #target #where_clause #contents } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 046b439ca..2885844af 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -36,12 +36,13 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { }; // Construct an 'ImplSkeleton' so that we can add trait bounds. - let bound = - parse_quote!(::domain::new_base::parse::SplitBytes<'bytes>); - let mut skeleton = ImplSkeleton::new(&input, false, bound); + let mut skeleton = ImplSkeleton::new(&input, false); // Add the parsing lifetime to the 'impl'. let lifetime = skeleton.new_lifetime("bytes"); + skeleton.bound = Some( + parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + ); if skeleton.lifetimes.len() > 0 { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { @@ -167,12 +168,13 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { }; // Construct an 'ImplSkeleton' so that we can add trait bounds. - let bound = - parse_quote!(::domain::new_base::parse::ParseBytes<'bytes>); - let mut skeleton = ImplSkeleton::new(&input, false, bound); + let mut skeleton = ImplSkeleton::new(&input, false); // Add the parsing lifetime to the 'impl'. let lifetime = skeleton.new_lifetime("bytes"); + skeleton.bound = Some( + parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + ); if skeleton.lifetimes.len() > 0 { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { @@ -311,8 +313,9 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { }; // Construct an 'ImplSkeleton' so that we can add trait bounds. - let bound = parse_quote!(::domain::new_base::parse::SplitBytesByRef); - let mut skeleton = ImplSkeleton::new(&input, true, bound); + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = + Some(parse_quote!(::domain::new_base::parse::SplitBytesByRef)); // Establish bounds on the fields. for field in data.fields.iter() { @@ -429,8 +432,9 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { }; // Construct an 'ImplSkeleton' so that we can add trait bounds. - let bound = parse_quote!(::domain::new_base::parse::ParseBytesByRef); - let mut skeleton = ImplSkeleton::new(&input, true, bound); + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = + Some(parse_quote!(::domain::new_base::parse::ParseBytesByRef)); // Establish bounds on the fields. for field in fields.clone() { @@ -505,3 +509,113 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { .unwrap_or_else(syn::Error::into_compile_error) .into() } + +//----------- BuildBytes ----------------------------------------------------- + +#[proc_macro_derive(BuildBytes)] +pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'BuildBytes' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'BuildBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, false); + skeleton.bound = + Some(parse_quote!(::domain::new_base::build::BuildBytes)); + + // Get a lifetime for the input buffer. + let lifetime = skeleton.new_lifetime("bytes"); + + // Establish bounds on the fields. + for field in data.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::build::BuildBytes), + ); + } + + // Define 'build_bytes()'. + let names = data.fields.members(); + let tys = data.fields.iter().map(|f| &f.ty); + skeleton.contents.stmts.push(parse_quote! { + fn build_bytes<#lifetime>( + &self, + mut bytes: & #lifetime mut [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + & #lifetime mut [::domain::__core::primitive::u8], + ::domain::new_base::build::TruncationError, + > { + #(bytes = <#tys as ::domain::new_base::build::BuildBytes> + ::build_bytes(&self.#names, bytes)?;)* + Ok(bytes) + } + }); + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- AsBytes -------------------------------------------------------- + +#[proc_macro_derive(AsBytes)] +pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'AsBytes' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'AsBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + let _ = Repr::determine(&input.attrs, "AsBytes")?; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = + Some(parse_quote!(::domain::new_base::build::AsBytes)); + + // Establish bounds on the fields. + for field in data.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::build::AsBytes), + ); + } + + // The default implementation of 'as_bytes()' works perfectly. + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 75a9cfc69..9245b9011 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -10,7 +10,7 @@ use zerocopy::{FromBytes, IntoBytes, SizeError}; use crate::new_base::{name::RevName, Header, Message}; -use super::{BuildInto, TruncationError}; +use super::{BuildBytes, TruncationError}; //----------- Builder -------------------------------------------------------- @@ -303,7 +303,7 @@ impl Builder<'_> { name: &RevName, ) -> Result<(), TruncationError> { // TODO: Perform name compression. - name.build_into(self.uninitialized())?; + name.build_bytes(self.uninitialized())?; self.mark_appended(name.len()); Ok(()) } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 108cc76f0..56670e922 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -2,6 +2,8 @@ use core::fmt; +use zerocopy::network_endian::{U16, U32}; + mod builder; pub use builder::{Builder, BuilderContext}; @@ -42,36 +44,38 @@ impl BuildIntoMessage for [u8] { //----------- Low-level building traits -------------------------------------- -/// Building into a byte string. -pub trait BuildInto { - /// Append this value to the byte string. +/// Serializing into a byte string. +pub trait BuildBytes { + /// Serialize into a byte string. /// - /// If the byte string is long enough to fit the message, the remaining - /// (unfilled) part of the byte string is returned. Otherwise, a - /// [`TruncationError`] is returned. - fn build_into<'b>( + /// `self` is serialized into a byte string and written to the given + /// buffer. If the buffer is large enough, the whole object is written + /// and the remaining (unmodified) part of the buffer is returned. + /// + /// if the buffer is too small, a [`TruncationError`] is returned (and + /// parts of the buffer may be modified). + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError>; } -impl BuildInto for &T { - fn build_into<'b>( +impl BuildBytes for &T { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { - (**self).build_into(bytes) + T::build_bytes(*self, bytes) } } -impl BuildInto for [u8] { - fn build_into<'b>( +impl BuildBytes for u8 { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { - if self.len() <= bytes.len() { - let (bytes, rest) = bytes.split_at_mut(self.len()); - bytes.copy_from_slice(self); + if let Some((elem, rest)) = bytes.split_first_mut() { + *elem = *self; Ok(rest) } else { Err(TruncationError) @@ -79,6 +83,88 @@ impl BuildInto for [u8] { } } +impl BuildBytes for U16 { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_bytes(bytes) + } +} + +impl BuildBytes for U32 { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_bytes(bytes) + } +} + +impl BuildBytes for [T] { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + for elem in self { + bytes = elem.build_bytes(bytes)?; + } + Ok(bytes) + } +} + +impl BuildBytes for [T; N] { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + for elem in self { + bytes = elem.build_bytes(bytes)?; + } + Ok(bytes) + } +} + +/// Interpreting a value as a byte string. +/// +/// # Safety +/// +/// A type `T` can soundly implement [`AsBytes`] if and only if: +/// +/// - It has no padding bytes. +/// - It has no interior mutability. +pub unsafe trait AsBytes { + /// Interpret this value as a sequence of bytes. + /// + /// ## Invariants + /// + /// For the statement `let bytes = this.as_bytes();`, + /// + /// - `bytes.as_ptr() as usize == this as *const _ as usize`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + /// + /// The default implementation automatically satisfies these invariants. + fn as_bytes(&self) -> &[u8] { + // SAFETY: + // - 'Self' has no padding bytes and no interior mutability. + // - Its size in memory is exactly 'size_of_val(self)'. + unsafe { + core::slice::from_raw_parts( + self as *const Self as *const u8, + core::mem::size_of_val(self), + ) + } + } +} + +unsafe impl AsBytes for u8 {} + +unsafe impl AsBytes for [T] {} +unsafe impl AsBytes for [T; N] {} + +unsafe impl AsBytes for U16 {} +unsafe impl AsBytes for U32 {} + //----------- TruncationError ------------------------------------------------ /// A DNS message did not fit in a buffer. diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 57f888c27..2a82e95fa 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -3,10 +3,9 @@ use core::{fmt, ops::Range}; use zerocopy::IntoBytes; -use zerocopy_derive::*; use super::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, BuildBytes, BuildIntoMessage, TruncationError}, parse::{ ParseBytes, ParseError, ParseFromMessage, SplitBytes, SplitFromMessage, @@ -17,7 +16,6 @@ use super::{ //----------- CharStr -------------------------------------------------------- /// A DNS "character string". -#[derive(Immutable, Unaligned)] #[repr(transparent)] pub struct CharStr { /// The underlying octets. @@ -93,15 +91,15 @@ impl<'a> ParseBytes<'a> for &'a CharStr { //--- Building into byte strings -impl BuildInto for CharStr { - fn build_into<'b>( +impl BuildBytes for CharStr { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { let (length, bytes) = bytes.split_first_mut().ok_or(TruncationError)?; *length = self.octets.len() as u8; - self.octets.build_into(bytes) + self.octets.build_bytes(bytes) } } diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 5d635a9e3..3307609bb 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -5,7 +5,7 @@ use core::fmt; use zerocopy::network_endian::U16; use zerocopy_derive::*; -use domain_macros::*; +use domain_macros::{AsBytes, *}; //----------- Message -------------------------------------------------------- @@ -35,6 +35,8 @@ pub struct Message { KnownLayout, Immutable, Unaligned, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -79,6 +81,8 @@ impl fmt::Display for Header { KnownLayout, Immutable, Unaligned, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -244,6 +248,8 @@ impl fmt::Display for HeaderFlags { KnownLayout, Immutable, Unaligned, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 7068e2e15..78ef94008 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -7,16 +7,16 @@ use core::{ iter::FusedIterator, }; -use zerocopy_derive::*; +use domain_macros::AsBytes; -use crate::new_base::parse::{ParseError, ParseBytes, SplitBytes}; +use crate::new_base::parse::{ParseBytes, ParseError, SplitBytes}; //----------- Label ---------------------------------------------------------- /// A label in a domain name. /// /// A label contains up to 63 bytes of arbitrary data. -#[derive(IntoBytes, Immutable, Unaligned)] +#[derive(AsBytes)] #[repr(transparent)] pub struct Label([u8]); diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index ee7b73b9e..6fae3c0f2 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -9,12 +9,12 @@ use core::{ }; use zerocopy::IntoBytes; -use zerocopy_derive::*; use crate::new_base::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, BuildBytes, BuildIntoMessage, TruncationError}, parse::{ - ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, Message, }; @@ -30,7 +30,6 @@ use super::LabelIter; /// use, making many common operations (e.g. comparing and ordering domain /// names) more computationally expensive. A [`RevName`] stores the labels in /// reversed order for more efficient use. -#[derive(Immutable, Unaligned)] #[repr(transparent)] pub struct RevName([u8]); @@ -113,8 +112,8 @@ impl BuildIntoMessage for RevName { //--- Building into byte strings -impl BuildInto for RevName { - fn build_into<'b>( +impl BuildBytes for RevName { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { @@ -213,7 +212,7 @@ impl fmt::Debug for RevName { //----------- RevNameBuf ----------------------------------------------------- /// A 256-byte buffer containing a [`RevName`]. -#[derive(Clone, Immutable, Unaligned)] +#[derive(Clone)] #[repr(C)] // make layout compatible with '[u8; 256]' pub struct RevNameBuf { /// The position of the root label in the buffer. @@ -422,12 +421,12 @@ impl<'a> ParseBytes<'a> for RevNameBuf { //--- Building into byte strings -impl BuildInto for RevNameBuf { - fn build_into<'b>( +impl BuildBytes for RevNameBuf { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { - (**self).build_into(bytes) + (**self).build_bytes(bytes) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 8a9ad771f..4e93951aa 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,13 +2,12 @@ use core::ops::Range; -use zerocopy::{network_endian::U16, IntoBytes}; -use zerocopy_derive::*; +use zerocopy::network_endian::U16; use domain_macros::*; use super::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, AsBytes, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ParseError, ParseFromMessage, SplitFromMessage}, Message, @@ -17,7 +16,7 @@ use super::{ //----------- Question ------------------------------------------------------- /// A DNS question. -#[derive(Clone, ParseBytes, SplitBytes)] +#[derive(Clone, BuildBytes, ParseBytes, SplitBytes)] pub struct Question { /// The domain name being requested. pub qname: N, @@ -95,23 +94,6 @@ where } } -//--- Building into byte strings - -impl BuildInto for Question -where - N: BuildInto, -{ - fn build_into<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.qname.build_into(bytes)?; - bytes = self.qtype.as_bytes().build_into(bytes)?; - bytes = self.qclass.as_bytes().build_into(bytes)?; - Ok(bytes) - } -} - //----------- QType ---------------------------------------------------------- /// The type of a question. @@ -124,8 +106,8 @@ where PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -149,8 +131,8 @@ pub struct QType { PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, diff --git a/src/new_base/record.rs b/src/new_base/record.rs index c02780d53..391b95dee 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -9,12 +9,11 @@ use zerocopy::{ network_endian::{U16, U32}, FromBytes, IntoBytes, }; -use zerocopy_derive::*; use domain_macros::*; use super::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, AsBytes, BuildBytes, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, @@ -191,25 +190,25 @@ where //--- Building into byte strings -impl BuildInto for Record +impl BuildBytes for Record where - N: BuildInto, - D: BuildInto, + N: BuildBytes, + D: BuildBytes, { - fn build_into<'b>( + fn build_bytes<'b>( &self, mut bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.rname.build_into(bytes)?; - bytes = self.rtype.as_bytes().build_into(bytes)?; - bytes = self.rclass.as_bytes().build_into(bytes)?; - bytes = self.ttl.as_bytes().build_into(bytes)?; + bytes = self.rname.build_bytes(bytes)?; + bytes = self.rtype.as_bytes().build_bytes(bytes)?; + bytes = self.rclass.as_bytes().build_bytes(bytes)?; + bytes = self.ttl.as_bytes().build_bytes(bytes)?; let (size, bytes) = ::mut_from_prefix(bytes).map_err(|_| TruncationError)?; let bytes_len = bytes.len(); - let rest = self.rdata.build_into(bytes)?; + let rest = self.rdata.build_bytes(bytes)?; *size = u16::try_from(bytes_len - rest.len()) .expect("the record data never exceeds 64KiB") .into(); @@ -230,8 +229,8 @@ where PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -289,8 +288,8 @@ impl RType { PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -314,8 +313,8 @@ pub struct RClass { PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -351,7 +350,7 @@ pub trait ParseRecordData<'a>: Sized { //----------- UnparsedRecordData --------------------------------------------- /// Unparsed DNS record data. -#[derive(Immutable, Unaligned)] +#[derive(AsBytes, BuildBytes)] #[repr(transparent)] pub struct UnparsedRecordData([u8]); @@ -398,17 +397,6 @@ impl BuildIntoMessage for UnparsedRecordData { } } -//--- Building into byte strings - -impl BuildInto for UnparsedRecordData { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.0.build_into(bytes) - } -} - //--- Access to the underlying bytes impl Deref for UnparsedRecordData { diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index 2fe5e8f7c..eaccf32f2 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -9,7 +9,6 @@ use core::{ }; use zerocopy::network_endian::U32; -use zerocopy_derive::*; use domain_macros::*; @@ -23,8 +22,8 @@ use domain_macros::*; PartialEq, Eq, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 53251ee7f..d9e8829ac 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -10,16 +10,12 @@ use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; -use zerocopy::{ - network_endian::{U16, U32}, - IntoBytes, -}; -use zerocopy_derive::*; +use zerocopy::network_endian::{U16, U32}; use domain_macros::*; use crate::new_base::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, AsBytes, BuildIntoMessage, TruncationError}, parse::{ ParseBytes, ParseError, ParseFromMessage, SplitBytes, SplitFromMessage, @@ -39,8 +35,8 @@ use crate::new_base::{ PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -101,17 +97,6 @@ impl BuildIntoMessage for A { } } -//--- Building into byte strings - -impl BuildInto for A { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_into(bytes) - } -} - //----------- Ns ------------------------------------------------------------- /// The authoritative name server for this domain. @@ -124,6 +109,7 @@ impl BuildInto for A { PartialOrd, Ord, Hash, + BuildBytes, ParseBytes, SplitBytes, )] @@ -155,17 +141,6 @@ impl BuildIntoMessage for Ns { } } -//--- Building into bytes - -impl BuildInto for Ns { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.name.build_into(bytes) - } -} - //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. @@ -178,6 +153,7 @@ impl BuildInto for Ns { PartialOrd, Ord, Hash, + BuildBytes, ParseBytes, SplitBytes, )] @@ -209,21 +185,20 @@ impl BuildIntoMessage for CName { } } -//--- Building into bytes - -impl BuildInto for CName { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.name.build_into(bytes) - } -} - //----------- Soa ------------------------------------------------------------ /// The start of a zone of authority. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, ParseBytes, SplitBytes)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Hash, + BuildBytes, + ParseBytes, + SplitBytes, +)] pub struct Soa { /// The name server which provided this zone. pub mname: N, @@ -293,28 +268,10 @@ impl BuildIntoMessage for Soa { } } -//--- Building into byte strings - -impl BuildInto for Soa { - fn build_into<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.mname.build_into(bytes)?; - bytes = self.rname.build_into(bytes)?; - bytes = self.serial.as_bytes().build_into(bytes)?; - bytes = self.refresh.as_bytes().build_into(bytes)?; - bytes = self.retry.as_bytes().build_into(bytes)?; - bytes = self.expire.as_bytes().build_into(bytes)?; - bytes = self.minimum.as_bytes().build_into(bytes)?; - Ok(bytes) - } -} - //----------- Wks ------------------------------------------------------------ /// Well-known services supported on this domain. -#[derive(IntoBytes, Immutable, ParseBytesByRef)] +#[derive(AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C, packed)] pub struct Wks { /// The address of the host providing these services. @@ -366,17 +323,6 @@ impl BuildIntoMessage for Wks { } } -//--- Building into byte strings - -impl BuildInto for Wks { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_into(bytes) - } -} - //----------- Ptr ------------------------------------------------------------ /// A pointer to another domain name. @@ -389,6 +335,7 @@ impl BuildInto for Wks { PartialOrd, Ord, Hash, + BuildBytes, ParseBytes, SplitBytes, )] @@ -420,21 +367,10 @@ impl BuildIntoMessage for Ptr { } } -//--- Building into bytes - -impl BuildInto for Ptr { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.name.build_into(bytes) - } -} - //----------- HInfo ---------------------------------------------------------- /// Information about the host computer. -#[derive(Clone, Debug, PartialEq, Eq, ParseBytes, SplitBytes)] +#[derive(Clone, Debug, PartialEq, Eq, BuildBytes, ParseBytes, SplitBytes)] pub struct HInfo<'a> { /// The CPU type. pub cpu: &'a CharStr, @@ -450,6 +386,8 @@ impl<'a> ParseFromMessage<'a> for HInfo<'a> { message: &'a Message, range: Range, ) -> Result { + use zerocopy::IntoBytes; + message .as_bytes() .get(range) @@ -472,19 +410,6 @@ impl BuildIntoMessage for HInfo<'_> { } } -//--- Building into bytes - -impl BuildInto for HInfo<'_> { - fn build_into<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.cpu.build_into(bytes)?; - bytes = self.os.build_into(bytes)?; - Ok(bytes) - } -} - //----------- Mx ------------------------------------------------------------- /// A host that can exchange mail for this domain. @@ -497,6 +422,7 @@ impl BuildInto for HInfo<'_> { PartialOrd, Ord, Hash, + BuildBytes, ParseBytes, SplitBytes, )] @@ -540,23 +466,10 @@ impl BuildIntoMessage for Mx { } } -//--- Building into byte strings - -impl BuildInto for Mx { - fn build_into<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.preference.as_bytes().build_into(bytes)?; - bytes = self.exchange.build_into(bytes)?; - Ok(bytes) - } -} - //----------- Txt ------------------------------------------------------------ /// Free-form text strings about this domain. -#[derive(IntoBytes, Immutable, Unaligned)] +#[derive(AsBytes, BuildBytes)] #[repr(transparent)] pub struct Txt { /// The text strings, as concatenated [`CharStr`]s. @@ -588,6 +501,8 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { message: &'a Message, range: Range, ) -> Result { + use zerocopy::IntoBytes; + message .as_bytes() .get(range) @@ -622,17 +537,6 @@ impl<'a> ParseBytes<'a> for &'a Txt { } } -//--- Building into byte strings - -impl BuildInto for Txt { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.content.build_into(bytes) - } -} - //--- Formatting impl fmt::Debug for Txt { diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 4f84ba837..c53a715a7 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -2,26 +2,15 @@ //! //! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). -use zerocopy_derive::*; - use domain_macros::*; -use crate::new_base::build::{ - self, BuildInto, BuildIntoMessage, TruncationError, -}; +use crate::new_base::build::{self, BuildIntoMessage, TruncationError}; //----------- Opt ------------------------------------------------------------ /// Extended DNS options. #[derive( - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - IntoBytes, - Immutable, - ParseBytesByRef, + PartialEq, Eq, PartialOrd, Ord, Hash, AsBytes, BuildBytes, ParseBytesByRef, )] #[repr(transparent)] pub struct Opt { @@ -42,14 +31,3 @@ impl BuildIntoMessage for Opt { self.contents.build_into_message(builder) } } - -//--- Building into byte strings - -impl BuildInto for Opt { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.contents.build_into(bytes) - } -} diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index 77df07cc5..fb3f9d30e 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -8,13 +8,10 @@ use core::{fmt, str::FromStr}; #[cfg(feature = "std")] use std::net::Ipv6Addr; -use zerocopy::IntoBytes; -use zerocopy_derive::*; - use domain_macros::*; use crate::new_base::build::{ - self, BuildInto, BuildIntoMessage, TruncationError, + self, AsBytes, BuildIntoMessage, TruncationError, }; //----------- Aaaa ----------------------------------------------------------- @@ -29,8 +26,8 @@ use crate::new_base::build::{ PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -90,14 +87,3 @@ impl BuildIntoMessage for Aaaa { self.as_bytes().build_into_message(builder) } } - -//--- Building into byte strings - -impl BuildInto for Aaaa { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_into(bytes) - } -} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 0cf020988..1be038e45 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -2,11 +2,10 @@ use core::ops::Range; -use domain_macros::ParseBytesByRef; -use zerocopy_derive::*; +use domain_macros::*; use crate::new_base::{ - build::{BuildInto, BuildIntoMessage, Builder, TruncationError}, + build::{BuildBytes, BuildIntoMessage, Builder, TruncationError}, parse::{ ParseBytes, ParseError, ParseFromMessage, SplitBytes, SplitFromMessage, @@ -150,23 +149,23 @@ impl BuildIntoMessage for RecordData<'_, N> { } } -impl BuildInto for RecordData<'_, N> { - fn build_into<'b>( +impl BuildBytes for RecordData<'_, N> { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { match self { - Self::A(r) => r.build_into(bytes), - Self::Ns(r) => r.build_into(bytes), - Self::CName(r) => r.build_into(bytes), - Self::Soa(r) => r.build_into(bytes), - Self::Wks(r) => r.build_into(bytes), - Self::Ptr(r) => r.build_into(bytes), - Self::HInfo(r) => r.build_into(bytes), - Self::Txt(r) => r.build_into(bytes), - Self::Aaaa(r) => r.build_into(bytes), - Self::Mx(r) => r.build_into(bytes), - Self::Unknown(_, r) => r.octets.build_into(bytes), + Self::A(r) => r.build_bytes(bytes), + Self::Ns(r) => r.build_bytes(bytes), + Self::CName(r) => r.build_bytes(bytes), + Self::Soa(r) => r.build_bytes(bytes), + Self::Wks(r) => r.build_bytes(bytes), + Self::Ptr(r) => r.build_bytes(bytes), + Self::HInfo(r) => r.build_bytes(bytes), + Self::Txt(r) => r.build_bytes(bytes), + Self::Aaaa(r) => r.build_bytes(bytes), + Self::Mx(r) => r.build_bytes(bytes), + Self::Unknown(_, r) => r.build_bytes(bytes), } } } @@ -174,7 +173,7 @@ impl BuildInto for RecordData<'_, N> { //----------- UnknownRecordData ---------------------------------------------- /// Data for an unknown DNS record type. -#[derive(Debug, IntoBytes, Immutable, ParseBytesByRef)] +#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C)] pub struct UnknownRecordData { /// The unparsed option data. From d1f94ba1bfa6ce21f933642524f81a530d3b3124 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 17:58:22 +0100 Subject: [PATCH 063/111] [macros] Minor fixes as per clippy --- macros/src/lib.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 2885844af..605cac3be 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -43,7 +43,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { skeleton.bound = Some( parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), ); - if skeleton.lifetimes.len() > 0 { + if !skeleton.lifetimes.is_empty() { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { #lifetime: #(#lifetimes)+* @@ -105,7 +105,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -175,7 +175,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { skeleton.bound = Some( parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), ); - if skeleton.lifetimes.len() > 0 { + if !skeleton.lifetimes.is_empty() { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { #lifetime: #(#lifetimes)+* @@ -252,7 +252,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -360,7 +360,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -501,7 +501,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -564,7 +564,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -611,7 +611,7 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { // The default implementation of 'as_bytes()' works perfectly. - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); From c0368ea875d9de4abcdf637e7f40a59022ad8bb2 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 18:52:20 +0100 Subject: [PATCH 064/111] [new_base/parse] Fix missing doc link --- src/new_base/parse/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 493542b66..c4bba79e4 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -154,6 +154,8 @@ impl<'a> ParseBytes<'a> for U32 { /// documented on [`split_bytes_by_ref()`]. An incorrect implementation is /// considered to cause undefined behaviour. /// +/// [`split_bytes_by_ref()`]: Self::split_bytes_by_ref() +/// /// Note that [`ParseBytesByRef`], required by this trait, also has several /// invariants that need to be considered with care. pub unsafe trait SplitBytesByRef: ParseBytesByRef { From 42d48e007b0fe64218bb454da50131762dc8fb61 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 18:52:34 +0100 Subject: [PATCH 065/111] [macros] Fix no-fields output for 'ParseBytes' --- macros/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 605cac3be..67285d420 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -160,7 +160,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { if bytes.is_empty() { Ok(Self #fields) } else { - Err() + Err(::domain::new_base::parse::ParseError) } } } From 90b531a2a540e6945a279e37558944c52f5e7f55 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 18:53:00 +0100 Subject: [PATCH 066/111] [new_base/serial] Support measuring unix time --- src/new_base/serial.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index eaccf32f2..4258c4b22 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -32,6 +32,21 @@ use domain_macros::*; #[repr(transparent)] pub struct Serial(U32); +//--- Construction + +impl Serial { + /// Measure the current time (in seconds) in serial number space. + #[cfg(feature = "std")] + pub fn unix_time() -> Self { + use std::time::SystemTime; + + let time = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("The current time is after the Unix Epoch"); + Self::from(time.as_secs() as u32) + } +} + //--- Addition impl Add for Serial { From f8fc52583788b4ba1ad053bdd0e126079aaffa4b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 18:53:13 +0100 Subject: [PATCH 067/111] [new_edns] Implement DNS cookie support --- src/new_edns/cookie.rs | 242 +++++++++++++++++++++++++++++++++++++++++ src/new_edns/mod.rs | 16 ++- 2 files changed, 252 insertions(+), 6 deletions(-) create mode 100644 src/new_edns/cookie.rs diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs new file mode 100644 index 000000000..466e8c606 --- /dev/null +++ b/src/new_edns/cookie.rs @@ -0,0 +1,242 @@ +//! DNS cookies. +//! +//! See [RFC 7873] and [RFC 9018]. +//! +//! [RFC 7873]: https://datatracker.ietf.org/doc/html/rfc7873 +//! [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 + +use core::fmt; + +#[cfg(all(feature = "std", feature = "siphasher"))] +use core::ops::Range; + +#[cfg(all(feature = "std", feature = "siphasher"))] +use std::net::IpAddr; + +use domain_macros::*; + +use crate::new_base::Serial; + +#[cfg(all(feature = "std", feature = "siphasher"))] +use crate::new_base::build::{AsBytes, TruncationError}; + +//----------- CookieRequest -------------------------------------------------- + +/// A request for a DNS cookie. +#[derive( + Copy, + Clone, + PartialEq, + Eq, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct CookieRequest { + /// The octets of the request. + pub octets: [u8; 8], +} + +//--- Construction + +impl CookieRequest { + /// Construct a random [`CookieRequest`]. + #[cfg(feature = "rand")] + pub fn random() -> Self { + rand::random::<[u8; 8]>().into() + } +} + +//--- Interaction + +impl CookieRequest { + /// Build a [`Cookie`] in response to this request. + /// + /// A 24-byte version-1 interoperable cookie will be generated and written + /// to the given buffer. If the buffer is big enough, the remaining part + /// of the buffer is returned. + #[cfg(all(feature = "std", feature = "siphasher"))] + pub fn respond_into<'b>( + &self, + addr: IpAddr, + secret: &[u8; 16], + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + use core::hash::Hasher; + + use siphasher::sip::SipHasher24; + + use crate::new_base::build::BuildBytes; + + // Build and hash the cookie simultaneously. + let mut hasher = SipHasher24::new_with_key(secret); + + bytes = self.build_bytes(bytes)?; + hasher.write(self.as_bytes()); + + // The version number and the reserved octets. + bytes = [1, 0, 0, 0].build_bytes(bytes)?; + hasher.write(&[1, 0, 0, 0]); + + let timestamp = Serial::unix_time(); + bytes = timestamp.build_bytes(bytes)?; + hasher.write(timestamp.as_bytes()); + + match addr { + IpAddr::V4(addr) => hasher.write(&addr.octets()), + IpAddr::V6(addr) => hasher.write(&addr.octets()), + } + + let hash = hasher.finish().to_le_bytes(); + bytes = hash.build_bytes(bytes)?; + + Ok(bytes) + } +} + +//--- Conversion to and from octets + +impl From<[u8; 8]> for CookieRequest { + fn from(value: [u8; 8]) -> Self { + Self { octets: value } + } +} + +impl From for [u8; 8] { + fn from(value: CookieRequest) -> Self { + value.octets + } +} + +//--- Formatting + +impl fmt::Debug for CookieRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "CookieRequest({})", self) + } +} + +impl fmt::Display for CookieRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:016X}", u64::from_be_bytes(self.octets)) + } +} + +//----------- Cookie --------------------------------------------------------- + +/// A DNS cookie. +#[derive(PartialEq, Eq, Hash, AsBytes, BuildBytes, ParseBytesByRef)] +#[repr(C)] +pub struct Cookie { + /// The request for this cookie. + request: CookieRequest, + + /// The version number of this cookie. + version: u8, + + /// Reserved bytes in the cookie format. + reversed: [u8; 3], + + /// When this cookie was made. + timestamp: Serial, + + /// The hash of this cookie. + hash: [u8], +} + +//--- Inspection + +impl Cookie { + /// The underlying cookie request. + pub fn request(&self) -> &CookieRequest { + &self.request + } + + /// The version number of this interoperable cookie. + /// + /// Assuming this is an interoperable cookie, as specified by [RFC 9018], + /// the 1-byte version number of the cookie is returned. Currently, only + /// version 1 has been specified. + /// + /// [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 + pub fn version(&self) -> u8 { + self.version + } + + /// When this interoperable cookie was produced. + /// + /// Assuming this is an interoperable cookie, as specified by [RFC 9018], + /// the 4-byte timestamp of the cookie is returned. + /// + /// [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 + pub fn timestamp(&self) -> Serial { + self.timestamp + } +} + +//--- Interaction + +impl Cookie { + /// Verify this cookie. + /// + /// This cookie is verified as a 24-byte version-1 interoperable cookie, + /// as specified by [RFC 9018]. A 16-byte secret is used to generate a + /// hash for this cookie, based on its fields and the IP address of the + /// client which used it. If the cookie was generated in the given time + /// period, and the generated hash matches the hash in the cookie, it is + /// valid. + /// + /// [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 + #[cfg(all(feature = "std", feature = "siphasher"))] + pub fn verify( + &self, + addr: IpAddr, + secret: &[u8; 16], + validity: Range, + ) -> Result<(), CookieError> { + use core::hash::Hasher; + + use siphasher::sip::SipHasher24; + + // Check basic features of the cookie. + if self.version != 1 + || self.hash.len() != 8 + || !validity.contains(&self.timestamp) + { + return Err(CookieError); + } + + // Check the cookie hash. + let mut hasher = SipHasher24::new_with_key(secret); + hasher.write(&self.as_bytes()[..16]); + match addr { + IpAddr::V4(addr) => hasher.write(&addr.octets()), + IpAddr::V6(addr) => hasher.write(&addr.octets()), + } + + if self.hash == hasher.finish().to_le_bytes() { + Ok(()) + } else { + Err(CookieError) + } + } +} + +//----------- CookieError ---------------------------------------------------- + +/// An invalid [`Cookie`] was encountered. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct CookieError; + +//--- Formatting + +impl fmt::Display for CookieError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("A DNS cookie could not be verified") + } +} diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index d5a1d366f..e72529fb3 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -5,7 +5,6 @@ use core::{fmt, ops::Range}; use zerocopy::{network_endian::U16, IntoBytes}; -use zerocopy_derive::*; use domain_macros::*; @@ -20,6 +19,11 @@ use crate::{ new_rdata::Opt, }; +//----------- EDNS option modules -------------------------------------------- + +mod cookie; +pub use cookie::{Cookie, CookieRequest}; + //----------- EdnsRecord ----------------------------------------------------- /// An Extended DNS record. @@ -137,8 +141,8 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { Clone, Default, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -217,8 +221,8 @@ pub enum EdnsOption<'b> { PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -233,7 +237,7 @@ pub struct OptionCode { //----------- UnknownOption -------------------------------------------------- /// Data for an unknown Extended DNS option. -#[derive(Debug, IntoBytes, Immutable, ParseBytesByRef)] +#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C)] pub struct UnknownOption { /// The unparsed option data. From d4455874ec2ee7ea416549643a6b28e53e56f96d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 3 Jan 2025 12:13:40 +0100 Subject: [PATCH 068/111] [macros] Factor out struct inspection and building --- macros/src/data.rs | 159 +++++++++++++++++++ macros/src/impls.rs | 16 ++ macros/src/lib.rs | 375 ++++++++++++++++++-------------------------- 3 files changed, 326 insertions(+), 224 deletions(-) create mode 100644 macros/src/data.rs diff --git a/macros/src/data.rs b/macros/src/data.rs new file mode 100644 index 000000000..6a0788b3e --- /dev/null +++ b/macros/src/data.rs @@ -0,0 +1,159 @@ +//! Working with structs, enums, and unions. + +use std::ops::Deref; + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{spanned::Spanned, *}; + +//----------- Struct --------------------------------------------------------- + +/// A defined 'struct'. +pub struct Struct { + /// The identifier for this 'struct'. + ident: Ident, + + /// The fields in this 'struct'. + fields: Fields, +} + +impl Struct { + /// Construct a [`Struct`] for a 'Self'. + pub fn new_as_self(fields: &Fields) -> Self { + Self { + ident: ::default().into(), + fields: fields.clone(), + } + } + + /// Whether this 'struct' has no fields. + pub fn is_empty(&self) -> bool { + self.fields.is_empty() + } + + /// The number of fields in this 'struct'. + pub fn num_fields(&self) -> usize { + self.fields.len() + } + + /// The fields of this 'struct'. + pub fn fields(&self) -> impl Iterator + '_ { + self.fields.iter() + } + + /// The sized fields of this 'struct'. + pub fn sized_fields(&self) -> impl Iterator + '_ { + self.fields().take(self.num_fields() - 1) + } + + /// The unsized field of this 'struct'. + pub fn unsized_field(&self) -> Option<&Field> { + self.fields.iter().next_back() + } + + /// The names of the fields of this 'struct'. + pub fn members(&self) -> impl Iterator + '_ { + self.fields + .iter() + .enumerate() + .map(|(i, f)| make_member(i, f)) + } + + /// The names of the sized fields of this 'struct'. + pub fn sized_members(&self) -> impl Iterator + '_ { + self.members().take(self.num_fields() - 1) + } + + /// The name of the last field of this 'struct'. + pub fn unsized_member(&self) -> Option { + self.fields + .iter() + .next_back() + .map(|f| make_member(self.num_fields() - 1, f)) + } + + /// Construct a builder for this 'struct'. + pub fn builder Ident>( + &self, + f: F, + ) -> StructBuilder<'_, F> { + StructBuilder { + target: self, + var_fn: f, + } + } +} + +/// Construct a [`Member`] from a field and index. +fn make_member(index: usize, field: &Field) -> Member { + match &field.ident { + Some(ident) => Member::Named(ident.clone()), + None => Member::Unnamed(Index { + index: index as u32, + span: field.ty.span(), + }), + } +} + +//----------- StructBuilder -------------------------------------------------- + +/// A means of constructing a 'struct'. +pub struct StructBuilder<'a, F: Fn(Member) -> Ident> { + /// The 'struct' being constructed. + target: &'a Struct, + + /// A map from field names to constructing variables. + var_fn: F, +} + +impl Ident> StructBuilder<'_, F> { + /// The initializing variables for this 'struct'. + pub fn init_vars(&self) -> impl Iterator + '_ { + self.members().map(&self.var_fn) + } + + /// The names of the sized fields of this 'struct'. + pub fn sized_init_vars(&self) -> impl Iterator + '_ { + self.sized_members().map(&self.var_fn) + } + + /// The name of the last field of this 'struct'. + pub fn unsized_init_var(&self) -> Option { + self.unsized_member().map(&self.var_fn) + } +} + +impl Ident> Deref for StructBuilder<'_, F> { + type Target = Struct; + + fn deref(&self) -> &Self::Target { + self.target + } +} + +impl Ident> ToTokens for StructBuilder<'_, F> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.ident; + match self.fields { + Fields::Named(_) => { + let members = self.members(); + let init_vars = self.init_vars(); + quote! { + #ident { #(#members: #init_vars),* } + } + } + + Fields::Unnamed(_) => { + let init_vars = self.init_vars(); + quote! { + #ident ( #(#init_vars),* ) + } + } + + Fields::Unit => { + quote! { #ident } + } + } + .to_tokens(tokens); + } +} diff --git a/macros/src/impls.rs b/macros/src/impls.rs index 5e3b884a0..2d9724f0e 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -186,6 +186,22 @@ impl ImplSkeleton { }) .unwrap() } + + /// Generate a unique lifetime parameter with the given prefix and bounds. + pub fn new_lifetime_param( + &self, + prefix: &str, + bounds: impl IntoIterator, + ) -> (Lifetime, LifetimeParam) { + let lifetime = self.new_lifetime(prefix); + let mut bounds = bounds.into_iter().peekable(); + let param = if bounds.peek().is_some() { + parse_quote! { #lifetime: #(#bounds)+* } + } else { + parse_quote! { #lifetime } + }; + (lifetime, param) + } } impl ToTokens for ImplSkeleton { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 67285d420..99d209fff 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -4,13 +4,15 @@ use proc_macro as pm; use proc_macro2::TokenStream; -use quote::{format_ident, quote, ToTokens}; -use spanned::Spanned; +use quote::{format_ident, ToTokens}; use syn::*; mod impls; use impls::ImplSkeleton; +mod data; +use data::Struct; + mod repr; use repr::Repr; @@ -39,58 +41,30 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { let mut skeleton = ImplSkeleton::new(&input, false); // Add the parsing lifetime to the 'impl'. - let lifetime = skeleton.new_lifetime("bytes"); + let (lifetime, param) = skeleton.new_lifetime_param( + "bytes", + skeleton.lifetimes.iter().map(|l| l.lifetime.clone()), + ); + skeleton.lifetimes.push(param); skeleton.bound = Some( parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), ); - if !skeleton.lifetimes.is_empty() { - let lifetimes = skeleton.lifetimes.iter(); - let param = parse_quote! { - #lifetime: #(#lifetimes)+* - }; - skeleton.lifetimes.push(param); - } else { - skeleton.lifetimes.push(parse_quote! { #lifetime }) - } + + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + let builder = data.builder(field_prefixed); // Establish bounds on the fields. - for field in data.fields.iter() { + for field in data.fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), ); } - // Construct a 'Self' expression. - let self_expr = match &data.fields { - Fields::Named(_) => { - let names = data.fields.members(); - let exprs = - names.clone().map(|n| format_ident!("field_{}", n)); - quote! { - Self { - #(#names: #exprs,)* - } - } - } - - Fields::Unnamed(_) => { - let exprs = data - .fields - .members() - .map(|n| format_ident!("field_{}", n)); - quote! { - Self(#(#exprs,)*) - } - } - - Fields::Unit => quote! { Self }, - }; - // Define 'parse_bytes()'. - let names = - data.fields.members().map(|n| format_ident!("field_{}", n)); - let tys = data.fields.iter().map(|f| &f.ty); + let init_vars = builder.init_vars(); + let tys = data.fields().map(|f| &f.ty); skeleton.contents.stmts.push(parse_quote! { fn split_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], @@ -98,10 +72,10 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { (Self, & #lifetime [::domain::__core::primitive::u8]), ::domain::new_base::parse::ParseError, > { - #(let (#names, bytes) = + #(let (#init_vars, bytes) = <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> ::split_bytes(bytes)?;)* - Ok((#self_expr, bytes)) + Ok((#builder, bytes)) } }); @@ -135,106 +109,62 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { } }; - // Split up the last field from the rest. - let mut fields = data.fields.iter(); - let Some(last) = fields.next_back() else { - // This type has no fields. Return a simple implementation. - assert!(input.generics.params.is_empty()); - let where_clause = input.generics.where_clause; - let name = input.ident; - - // This will tokenize to '{}', '()', or ''. - let fields = data.fields.to_token_stream(); - - return Ok(quote! { - impl <'bytes> - ::domain::new_base::parse::ParseBytes<'bytes> - for #name - #where_clause { - fn parse_bytes( - bytes: &'bytes [::domain::__core::primitive::u8], - ) -> ::domain::__core::result::Result< - Self, - ::domain::new_base::parse::ParseError, - > { - if bytes.is_empty() { - Ok(Self #fields) - } else { - Err(::domain::new_base::parse::ParseError) - } - } - } - }); - }; - // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, false); // Add the parsing lifetime to the 'impl'. - let lifetime = skeleton.new_lifetime("bytes"); + let (lifetime, param) = skeleton.new_lifetime_param( + "bytes", + skeleton.lifetimes.iter().map(|l| l.lifetime.clone()), + ); + skeleton.lifetimes.push(param); skeleton.bound = Some( parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), ); - if !skeleton.lifetimes.is_empty() { - let lifetimes = skeleton.lifetimes.iter(); - let param = parse_quote! { - #lifetime: #(#lifetimes)+* - }; - skeleton.lifetimes.push(param); - } else { - skeleton.lifetimes.push(parse_quote! { #lifetime }) - } + + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + let builder = data.builder(field_prefixed); // Establish bounds on the fields. - for field in fields.clone() { - // This field should implement 'SplitBytes'. + for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), ); } - // The last field should implement 'ParseBytes'. - skeleton.require_bound( - last.ty.clone(), - parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), - ); + if let Some(field) = data.unsized_field() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + ); + } - // Construct a 'Self' expression. - let self_expr = match &data.fields { - Fields::Named(_) => { - let names = data.fields.members(); - let exprs = - names.clone().map(|n| format_ident!("field_{}", n)); - quote! { - Self { - #(#names: #exprs,)* + // Finish early if the 'struct' has no fields. + if data.is_empty() { + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes( + bytes: & #lifetime [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + Self, + ::domain::new_base::parse::ParseError, + > { + if bytes.is_empty() { + Ok(#builder) + } else { + Err(::domain::new_base::parse::ParseError) } } - } - - Fields::Unnamed(_) => { - let exprs = data - .fields - .members() - .map(|n| format_ident!("field_{}", n)); - quote! { - Self(#(#exprs,)*) - } - } + }); - Fields::Unit => unreachable!(), - }; + return Ok(skeleton.into_token_stream()); + } // Define 'parse_bytes()'. - let names = data - .fields - .members() - .take(fields.len()) - .map(|n| format_ident!("field_{}", n)); - let tys = fields.clone().map(|f| &f.ty); - let last_ty = &last.ty; - let last_name = - format_ident!("field_{}", data.fields.members().last().unwrap()); + let init_vars = builder.sized_init_vars(); + let tys = builder.sized_fields().map(|f| &f.ty); + let unsized_ty = &builder.unsized_field().unwrap().ty; + let unsized_init_var = builder.unsized_init_var().unwrap(); skeleton.contents.stmts.push(parse_quote! { fn parse_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], @@ -242,13 +172,13 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { Self, ::domain::new_base::parse::ParseError, > { - #(let (#names, bytes) = + #(let (#init_vars, bytes) = <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> ::split_bytes(bytes)?;)* - let #last_name = - <#last_ty as ::domain::new_base::parse::ParseBytes<#lifetime>> + let #unsized_init_var = + <#unsized_ty as ::domain::new_base::parse::ParseBytes<#lifetime>> ::parse_bytes(bytes)?; - Ok(#self_expr) + Ok(#builder) } }); @@ -284,50 +214,47 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { let _ = Repr::determine(&input.attrs, "SplitBytesByRef")?; - // Split up the last field from the rest. - let mut fields = data.fields.iter(); - let Some(last) = fields.next_back() else { - // This type has no fields. Return a simple implementation. - let (impl_generics, ty_generics, where_clause) = - input.generics.split_for_impl(); - let name = input.ident; - - return Ok(quote! { - unsafe impl #impl_generics - ::domain::new_base::parse::SplitBytesByRef - for #name #ty_generics - #where_clause { - fn split_bytes_by_ref( - bytes: &[::domain::__core::primitive::u8], - ) -> ::domain::__core::result::Result< - (&Self, &[::domain::__core::primitive::u8]), - ::domain::new_base::parse::ParseError, - > { - Ok(( - unsafe { &*bytes.as_ptr().cast::() }, - bytes, - )) - } - } - }); - }; - // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = Some(parse_quote!(::domain::new_base::parse::SplitBytesByRef)); + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + // Establish bounds on the fields. - for field in data.fields.iter() { + for field in data.fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::parse::SplitBytesByRef), ); } + // Finish early if the 'struct' has no fields. + if data.is_empty() { + skeleton.contents.stmts.push(parse_quote! { + fn split_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&Self, &[::domain::__core::primitive::u8]), + ::domain::new_base::parse::ParseError, + > { + Ok(( + // SAFETY: 'Self' is a 'struct' with no fields, + // and so has size 0 and alignment 1. It can be + // constructed at any address. + unsafe { &*bytes.as_ptr().cast::() }, + bytes, + )) + } + }); + + return Ok(skeleton.into_token_stream()); + } + // Define 'split_bytes_by_ref()'. - let tys = fields.clone().map(|f| &f.ty); - let last_ty = &last.ty; + let tys = data.sized_fields().map(|f| &f.ty); + let unsized_ty = &data.unsized_field().unwrap().ty; skeleton.contents.stmts.push(parse_quote! { fn split_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], @@ -340,10 +267,10 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { <#tys as ::domain::new_base::parse::SplitBytesByRef> ::split_bytes_by_ref(bytes)?;)* let (last, rest) = - <#last_ty as ::domain::new_base::parse::SplitBytesByRef> + <#unsized_ty as ::domain::new_base::parse::SplitBytesByRef> ::split_bytes_by_ref(bytes)?; let ptr = - <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> ::ptr_with_address(last, start as *const ()); // SAFETY: @@ -392,67 +319,63 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { let _ = Repr::determine(&input.attrs, "ParseBytesByRef")?; - // Split up the last field from the rest. - let mut fields = data.fields.iter(); - let Some(last) = fields.next_back() else { - // This type has no fields. Return a simple implementation. - let (impl_generics, ty_generics, where_clause) = - input.generics.split_for_impl(); - let name = input.ident; - - return Ok(quote! { - unsafe impl #impl_generics - ::domain::new_base::parse::ParseBytesByRef - for #name #ty_generics - #where_clause { - fn parse_bytes_by_ref( - bytes: &[::domain::__core::primitive::u8], - ) -> ::domain::__core::result::Result< - &Self, - ::domain::new_base::parse::ParseError, - > { - if bytes.is_empty() { - // SAFETY: 'Self' is a 'struct' with no fields, - // and so has size 0 and alignment 1. It can be - // constructed at any address. - Ok(unsafe { &*bytes.as_ptr().cast::() }) - } else { - Err(::domain::new_base::parse::ParseError) - } - } - - fn ptr_with_address( - &self, - addr: *const (), - ) -> *const Self { - addr.cast() - } - } - }); - }; - // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = Some(parse_quote!(::domain::new_base::parse::ParseBytesByRef)); + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + // Establish bounds on the fields. - for field in fields.clone() { - // This field should implement 'SplitBytesByRef'. + for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::parse::SplitBytesByRef), ); } - // The last field should implement 'ParseBytesByRef'. - skeleton.require_bound( - last.ty.clone(), - parse_quote!(::domain::new_base::parse::ParseBytesByRef), - ); + if let Some(field) = data.unsized_field() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::ParseBytesByRef), + ); + } + + // Finish early if the 'struct' has no fields. + if data.is_empty() { + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &Self, + ::domain::new_base::parse::ParseError, + > { + if bytes.is_empty() { + // SAFETY: 'Self' is a 'struct' with no fields, + // and so has size 0 and alignment 1. It can be + // constructed at any address. + Ok(unsafe { &*bytes.as_ptr().cast::() }) + } else { + Err(::domain::new_base::parse::ParseError) + } + } + }); + + skeleton.contents.stmts.push(parse_quote! { + fn ptr_with_address( + &self, + addr: *const (), + ) -> *const Self { + addr.cast() + } + }); + + return Ok(skeleton.into_token_stream()); + } // Define 'parse_bytes_by_ref()'. - let tys = fields.clone().map(|f| &f.ty); - let last_ty = &last.ty; + let tys = data.sized_fields().map(|f| &f.ty); + let unsized_ty = &data.unsized_field().unwrap().ty; skeleton.contents.stmts.push(parse_quote! { fn parse_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], @@ -465,10 +388,10 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { <#tys as ::domain::new_base::parse::SplitBytesByRef> ::split_bytes_by_ref(bytes)?;)* let last = - <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> ::parse_bytes_by_ref(bytes)?; let ptr = - <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> ::ptr_with_address(last, start as *const ()); // SAFETY: @@ -486,17 +409,11 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { }); // Define 'ptr_with_address()'. - let last_name = match last.ident.as_ref() { - Some(ident) => Member::Named(ident.clone()), - None => Member::Unnamed(Index { - index: data.fields.len() as u32 - 1, - span: last.ty.span(), - }), - }; + let unsized_member = data.unsized_member(); skeleton.contents.stmts.push(parse_quote! { fn ptr_with_address(&self, addr: *const ()) -> *const Self { - <#last_ty as ::domain::new_base::parse::ParseBytesByRef> - ::ptr_with_address(&self.#last_name, addr) + <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + ::ptr_with_address(&self.#unsized_member, addr) as *const Self } }); @@ -536,11 +453,14 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { skeleton.bound = Some(parse_quote!(::domain::new_base::build::BuildBytes)); + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + // Get a lifetime for the input buffer. let lifetime = skeleton.new_lifetime("bytes"); // Establish bounds on the fields. - for field in data.fields.iter() { + for field in data.fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::build::BuildBytes), @@ -548,8 +468,8 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { } // Define 'build_bytes()'. - let names = data.fields.members(); - let tys = data.fields.iter().map(|f| &f.ty); + let members = data.members(); + let tys = data.fields().map(|f| &f.ty); skeleton.contents.stmts.push(parse_quote! { fn build_bytes<#lifetime>( &self, @@ -559,7 +479,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { ::domain::new_base::build::TruncationError, > { #(bytes = <#tys as ::domain::new_base::build::BuildBytes> - ::build_bytes(&self.#names, bytes)?;)* + ::build_bytes(&self.#members, bytes)?;)* Ok(bytes) } }); @@ -619,3 +539,10 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { .unwrap_or_else(syn::Error::into_compile_error) .into() } + +//----------- Utility Functions ---------------------------------------------- + +/// Add a `field_` prefix to member names. +fn field_prefixed(member: Member) -> Ident { + format_ident!("field_{}", member) +} From 21dfd3daa6a612b7b662af54736ab7a44316342f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 14:00:43 +0100 Subject: [PATCH 069/111] [new_edns] Impl RFC 8914 "Extended DNS errors" --- src/new_base/parse/mod.rs | 17 +++ src/new_edns/ext_err.rs | 210 ++++++++++++++++++++++++++++++++++++++ src/new_edns/mod.rs | 3 + 3 files changed, 230 insertions(+) create mode 100644 src/new_edns/ext_err.rs diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index c4bba79e4..32bc3627a 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -321,6 +321,23 @@ unsafe impl ParseBytesByRef for [u8] { } } +unsafe impl ParseBytesByRef for str { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + core::str::from_utf8(bytes).map_err(|_| ParseError) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + // NOTE: The Rust Reference indicates that 'str' has the same layout + // as '[u8]' [1]. This is also the most natural layout for it. Since + // there's no way to construct a '*const str' from raw parts, we will + // just construct a raw slice and transmute it. + // + // [1]: https://doc.rust-lang.org/reference/type-layout.html#str-layout + + self.as_bytes().ptr_with_address(addr) as *const Self + } +} + unsafe impl SplitBytesByRef for [T; N] { fn split_bytes_by_ref( mut bytes: &[u8], diff --git a/src/new_edns/ext_err.rs b/src/new_edns/ext_err.rs new file mode 100644 index 000000000..030df6814 --- /dev/null +++ b/src/new_edns/ext_err.rs @@ -0,0 +1,210 @@ +//! Extended DNS errors. +//! +//! See [RFC 8914](https://datatracker.ietf.org/doc/html/rfc8914). + +use core::fmt; + +use domain_macros::*; + +use zerocopy::network_endian::U16; + +//----------- ExtError ------------------------------------------------------- + +/// An extended DNS error. +#[derive(ParseBytesByRef)] +#[repr(C)] +pub struct ExtError { + /// The error code. + pub code: ExtErrorCode, + + /// A human-readable description of the error. + text: str, +} + +impl ExtError { + /// A human-readable description of the error. + pub fn text(&self) -> Option<&str> { + if !self.text.is_empty() { + Some(self.text.strip_suffix('\0').unwrap_or(&self.text)) + } else { + None + } + } +} + +//----------- ExtErrorCode --------------------------------------------------- + +/// The code for an extended DNS error. +#[derive( + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct ExtErrorCode { + inner: U16, +} + +//--- Associated Constants + +impl ExtErrorCode { + const fn new(inner: u16) -> Self { + Self { + inner: U16::new(inner), + } + } + + /// An unspecified extended error. + /// + /// This should be used when there is no other appropriate error code. + pub const OTHER: Self = Self::new(0); + + /// DNSSEC validation failed because a DNSKEY used an unknown algorithm. + pub const BAD_DNSKEY_ALG: Self = Self::new(1); + + /// DNSSEC validation failed because a DS set used an unknown algorithm. + pub const BAD_DS_ALG: Self = Self::new(2); + + /// An up-to-date answer could not be retrieved in time. + pub const STALE_ANSWER: Self = Self::new(3); + + /// Policy dictated that a forged answer be returned. + pub const FORGED_ANSWER: Self = Self::new(4); + + /// The DNSSEC validity of the answer could not be determined. + pub const DNSSEC_INDETERMINATE: Self = Self::new(5); + + /// The answer was invalid as per DNSSEC. + pub const DNSSEC_BOGUS: Self = Self::new(6); + + /// The DNSSEC signature of the answer expired. + pub const SIG_EXPIRED: Self = Self::new(7); + + /// The DNSSEC signature of the answer is valid in the future. + pub const SIG_FUTURE: Self = Self::new(8); + + /// DNSSEC validation failed because a DNSKEY record was missing. + pub const DNSKEY_MISSING: Self = Self::new(9); + + /// DNSSEC validation failed because RRSIGs were unexpectedly missing. + pub const RRSIGS_MISSING: Self = Self::new(10); + + /// DNSSEC validation failed because a DNSKEY wasn't a ZSK. + pub const NOT_ZSK: Self = Self::new(11); + + /// DNSSEC validation failed because an NSEC(3) record could not be found. + pub const NSEC_MISSING: Self = Self::new(12); + + /// The server failure error was cached from an upstream. + pub const CACHED_ERROR: Self = Self::new(13); + + /// The server is not ready to serve requests. + pub const NOT_READY: Self = Self::new(14); + + /// The request is blocked by internal policy. + pub const BLOCKED: Self = Self::new(15); + + /// The request is blocked by external policy. + pub const CENSORED: Self = Self::new(16); + + /// The request is blocked by the client's own filters. + pub const FILTERED: Self = Self::new(17); + + /// The client is prohibited from making requests. + pub const PROHIBITED: Self = Self::new(18); + + /// An up-to-date answer could not be retrieved in time. + pub const STALE_NXDOMAIN: Self = Self::new(19); + + /// The request cannot be answered authoritatively. + pub const NOT_AUTHORITATIVE: Self = Self::new(20); + + /// The request / operation is not supported. + pub const NOT_SUPPORTED: Self = Self::new(21); + + /// No upstream authorities answered the request (in time). + pub const NO_REACHABLE_AUTHORITY: Self = Self::new(22); + + /// An unrecoverable network error occurred. + pub const NETWORK_ERROR: Self = Self::new(23); + + /// The server's local zone data is invalid. + pub const INVALID_DATA: Self = Self::new(24); + + /// An impure operation was stated in a DNS-over-QUIC 0-RTT packet. + /// + /// See [RFC 9250](https://datatracker.ietf.org/doc/html/rfc9250). + pub const TOO_EARLY: Self = Self::new(26); + + /// DNSSEC validation failed because an NSEC3 parameter was unsupported. + pub const BAD_NSEC3_ITERS: Self = Self::new(27); +} + +//--- Inspection + +impl ExtErrorCode { + /// Whether this is a private-use code. + /// + /// Private-use codes occupy the range 49152 to 65535 (inclusive). + pub fn is_private(&self) -> bool { + self.inner >= 49152 + } +} + +//--- Formatting + +impl fmt::Debug for ExtErrorCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let text = match *self { + Self::OTHER => "other", + Self::BAD_DNSKEY_ALG => "unsupported DNSKEY algorithm", + Self::BAD_DS_ALG => "unspported DS digest type", + Self::STALE_ANSWER => "stale answer", + Self::FORGED_ANSWER => "forged answer", + Self::DNSSEC_INDETERMINATE => "DNSSEC indeterminate", + Self::DNSSEC_BOGUS => "DNSSEC bogus", + Self::SIG_EXPIRED => "signature expired", + Self::SIG_FUTURE => "signature not yet valid", + Self::DNSKEY_MISSING => "DNSKEY missing", + Self::RRSIGS_MISSING => "RRSIGs missing", + Self::NOT_ZSK => "no zone key bit set", + Self::NSEC_MISSING => "nsec missing", + Self::CACHED_ERROR => "cached error", + Self::NOT_READY => "not ready", + Self::BLOCKED => "blocked", + Self::CENSORED => "censored", + Self::FILTERED => "filtered", + Self::PROHIBITED => "prohibited", + Self::STALE_NXDOMAIN => "stale NXDOMAIN answer", + Self::NOT_AUTHORITATIVE => "not authoritative", + Self::NOT_SUPPORTED => "not supported", + Self::NO_REACHABLE_AUTHORITY => "no reachable authority", + Self::NETWORK_ERROR => "network error", + Self::INVALID_DATA => "invalid data", + Self::TOO_EARLY => "too early", + Self::BAD_NSEC3_ITERS => "unsupported NSEC3 iterations value", + + _ => { + return f + .debug_tuple("ExtErrorCode") + .field(&self.inner.get()) + .finish(); + } + }; + + f.debug_tuple("ExtErrorCode") + .field(&self.inner.get()) + .field(&text) + .finish() + } +} diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index e72529fb3..b14170919 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -24,6 +24,9 @@ use crate::{ mod cookie; pub use cookie::{Cookie, CookieRequest}; +mod ext_err; +pub use ext_err::{ExtError, ExtErrorCode}; + //----------- EdnsRecord ----------------------------------------------------- /// An Extended DNS record. From 16596213071f965c09c17a5053ae5e44a5ebb043 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 14:42:33 +0100 Subject: [PATCH 070/111] [new_edns] Impl parsing/building for 'EdnsOption' --- src/new_base/build/mod.rs | 10 +++ src/new_edns/cookie.rs | 4 +- src/new_edns/ext_err.rs | 26 +++++-- src/new_edns/mod.rs | 144 +++++++++++++++++++++++++++++++++++++- src/new_rdata/mod.rs | 2 +- 5 files changed, 175 insertions(+), 11 deletions(-) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 56670e922..548b2d8fd 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -83,6 +83,15 @@ impl BuildBytes for u8 { } } +impl BuildBytes for str { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_bytes(bytes) + } +} + impl BuildBytes for U16 { fn build_bytes<'b>( &self, @@ -158,6 +167,7 @@ pub unsafe trait AsBytes { } unsafe impl AsBytes for u8 {} +unsafe impl AsBytes for str {} unsafe impl AsBytes for [T] {} unsafe impl AsBytes for [T; N] {} diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 466e8c606..1a815e615 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -130,7 +130,9 @@ impl fmt::Display for CookieRequest { //----------- Cookie --------------------------------------------------------- /// A DNS cookie. -#[derive(PartialEq, Eq, Hash, AsBytes, BuildBytes, ParseBytesByRef)] +#[derive( + Debug, PartialEq, Eq, Hash, AsBytes, BuildBytes, ParseBytesByRef, +)] #[repr(C)] pub struct Cookie { /// The request for this cookie. diff --git a/src/new_edns/ext_err.rs b/src/new_edns/ext_err.rs index 030df6814..6858aa713 100644 --- a/src/new_edns/ext_err.rs +++ b/src/new_edns/ext_err.rs @@ -11,7 +11,7 @@ use zerocopy::network_endian::U16; //----------- ExtError ------------------------------------------------------- /// An extended DNS error. -#[derive(ParseBytesByRef)] +#[derive(AsBytes, ParseBytesByRef)] #[repr(C)] pub struct ExtError { /// The error code. @@ -32,6 +32,17 @@ impl ExtError { } } +//--- Formatting + +impl fmt::Debug for ExtError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ExtError") + .field("code", &self.code) + .field("text", &self.text()) + .finish() + } +} + //----------- ExtErrorCode --------------------------------------------------- /// The code for an extended DNS error. @@ -52,15 +63,16 @@ impl ExtError { )] #[repr(transparent)] pub struct ExtErrorCode { - inner: U16, + /// The error code. + pub code: U16, } //--- Associated Constants impl ExtErrorCode { - const fn new(inner: u16) -> Self { + const fn new(code: u16) -> Self { Self { - inner: U16::new(inner), + code: U16::new(code), } } @@ -157,7 +169,7 @@ impl ExtErrorCode { /// /// Private-use codes occupy the range 49152 to 65535 (inclusive). pub fn is_private(&self) -> bool { - self.inner >= 49152 + self.code >= 49152 } } @@ -197,13 +209,13 @@ impl fmt::Debug for ExtErrorCode { _ => { return f .debug_tuple("ExtErrorCode") - .field(&self.inner.get()) + .field(&self.code.get()) .finish(); } }; f.debug_tuple("ExtErrorCode") - .field(&self.inner.get()) + .field(&self.code.get()) .field(&text) .finish() } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index b14170919..3b3ab2a80 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -10,6 +10,7 @@ use domain_macros::*; use crate::{ new_base::{ + build::{AsBytes, BuildBytes, TruncationError}, parse::{ ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, SplitBytes, SplitFromMessage, @@ -208,17 +209,123 @@ impl fmt::Debug for EdnsFlags { #[derive(Debug)] #[non_exhaustive] pub enum EdnsOption<'b> { + /// A request for a DNS cookie. + CookieRequest(&'b CookieRequest), + + /// A DNS cookie. + Cookie(&'b Cookie), + + /// An extended DNS error. + ExtError(&'b ExtError), + /// An unknown option. Unknown(OptionCode, &'b UnknownOption), } +//--- Inspection + +impl EdnsOption<'_> { + /// The code for this option. + pub fn code(&self) -> OptionCode { + match self { + Self::CookieRequest(_) => OptionCode::COOKIE, + Self::Cookie(_) => OptionCode::COOKIE, + Self::ExtError(_) => OptionCode::EXT_ERROR, + Self::Unknown(code, _) => *code, + } + } +} + +//--- Parsing from bytes + +impl<'b> ParseBytes<'b> for EdnsOption<'b> { + fn parse_bytes(bytes: &'b [u8]) -> Result { + let (code, rest) = OptionCode::split_bytes(bytes)?; + let (size, rest) = U16::split_bytes(rest)?; + if rest.len() != size.get() as usize { + return Err(ParseError); + } + + match code { + OptionCode::COOKIE => match size.get() { + 8 => CookieRequest::parse_bytes_by_ref(rest) + .map(Self::CookieRequest), + 16..=40 => Cookie::parse_bytes_by_ref(rest).map(Self::Cookie), + _ => Err(ParseError), + }, + + OptionCode::EXT_ERROR => { + ExtError::parse_bytes_by_ref(rest).map(Self::ExtError) + } + + _ => { + let data = UnknownOption::parse_bytes_by_ref(rest)?; + Ok(Self::Unknown(code, data)) + } + } + } +} + +impl<'b> SplitBytes<'b> for EdnsOption<'b> { + fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { + let (code, rest) = OptionCode::split_bytes(bytes)?; + let (size, rest) = U16::split_bytes(rest)?; + if rest.len() < size.get() as usize { + return Err(ParseError); + } + let (bytes, rest) = rest.split_at(size.get() as usize); + + match code { + OptionCode::COOKIE => match size.get() { + 8 => CookieRequest::parse_bytes_by_ref(bytes) + .map(Self::CookieRequest), + 16..=40 => { + Cookie::parse_bytes_by_ref(bytes).map(Self::Cookie) + } + _ => Err(ParseError), + }, + + OptionCode::EXT_ERROR => { + ExtError::parse_bytes_by_ref(bytes).map(Self::ExtError) + } + + _ => { + let data = UnknownOption::parse_bytes_by_ref(bytes)?; + Ok(Self::Unknown(code, data)) + } + } + .map(|this| (this, rest)) + } +} + +//--- Building byte strings + +impl BuildBytes for EdnsOption<'_> { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.code().build_bytes(bytes)?; + + let data = match self { + Self::CookieRequest(this) => this.as_bytes(), + Self::Cookie(this) => this.as_bytes(), + Self::ExtError(this) => this.as_bytes(), + Self::Unknown(_, this) => this.as_bytes(), + }; + + bytes = U16::new(data.len() as u16).build_bytes(bytes)?; + bytes = data.build_bytes(bytes)?; + Ok(bytes) + } +} + //----------- OptionCode ----------------------------------------------------- /// An Extended DNS option code. #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -237,11 +344,44 @@ pub struct OptionCode { pub code: U16, } +//--- Associated Constants + +impl OptionCode { + const fn new(code: u16) -> Self { + Self { + code: U16::new(code), + } + } + + /// A DNS cookie (request). + pub const COOKIE: Self = Self::new(10); + + /// An extended DNS error. + pub const EXT_ERROR: Self = Self::new(15); +} + +//--- Formatting + +impl fmt::Debug for OptionCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::COOKIE => "OptionCode::COOKIE", + Self::EXT_ERROR => "OptionCode::EXT_ERROR", + _ => { + return f + .debug_tuple("OptionCode") + .field(&self.code.get()) + .finish(); + } + }) + } +} + //----------- UnknownOption -------------------------------------------------- /// Data for an unknown Extended DNS option. #[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] -#[repr(C)] +#[repr(transparent)] pub struct UnknownOption { /// The unparsed option data. pub octets: [u8], diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 1be038e45..ebdcc7743 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -174,7 +174,7 @@ impl BuildBytes for RecordData<'_, N> { /// Data for an unknown DNS record type. #[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] -#[repr(C)] +#[repr(transparent)] pub struct UnknownRecordData { /// The unparsed option data. pub octets: [u8], From 7bb6d2c3506748014816b76d81da01aa88876e4f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 18:27:27 +0100 Subject: [PATCH 071/111] [new_base] Add module 'wire' to replace 'zerocopy' This ended up collecting a lot of small changes as I tried to get things to compile. - All the bytes parsing/building traits have been moved to 'wire'. - The 'wire::ints' module replaces 'U16' and 'U32' from 'zerocopy'. - '{Parse,Split}BytesByRef' now support parsing from '&mut'. - Every derive macro is documented under a re-export in 'wire'. The remaining contents of 'new_base::{build, parse}' might get moved into a shared 'message' module at some point. We'll see. --- Cargo.lock | 30 +- Cargo.toml | 6 - macros/Cargo.toml | 2 +- macros/src/lib.rs | 138 ++++++--- src/new_base/build/builder.rs | 23 +- src/new_base/build/mod.rs | 153 +--------- src/new_base/charstr.rs | 11 +- src/new_base/message.rs | 24 +- src/new_base/mod.rs | 1 + src/new_base/name/label.rs | 6 +- src/new_base/name/reversed.rs | 13 +- src/new_base/parse/mod.rs | 346 +---------------------- src/new_base/question.rs | 7 +- src/new_base/record.rs | 35 +-- src/new_base/serial.rs | 4 +- src/new_base/wire/build.rs | 192 +++++++++++++ src/new_base/wire/ints.rs | 282 +++++++++++++++++++ src/new_base/wire/mod.rs | 81 ++++++ src/new_base/wire/parse.rs | 510 ++++++++++++++++++++++++++++++++++ src/new_edns/cookie.rs | 4 +- src/new_edns/ext_err.rs | 4 +- src/new_edns/mod.rs | 10 +- src/new_rdata/basic.rs | 15 +- src/new_rdata/ipv6.rs | 5 +- src/new_rdata/mod.rs | 10 +- 25 files changed, 1253 insertions(+), 659 deletions(-) create mode 100644 src/new_base/wire/build.rs create mode 100644 src/new_base/wire/ints.rs create mode 100644 src/new_base/wire/mod.rs create mode 100644 src/new_base/wire/parse.rs diff --git a/Cargo.lock b/Cargo.lock index d9833efa5..953edfb0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -279,8 +279,6 @@ dependencies = [ "tracing", "tracing-subscriber", "webpki-roots", - "zerocopy 0.8.13", - "zerocopy-derive 0.8.13", ] [[package]] @@ -809,7 +807,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -1196,9 +1194,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.79" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -1702,16 +1700,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive 0.7.35", -] - -[[package]] -name = "zerocopy" -version = "0.8.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67914ab451f3bfd2e69e5e9d2ef3858484e7074d63f204fd166ec391b54de21d" -dependencies = [ - "zerocopy-derive 0.8.13", + "zerocopy-derive", ] [[package]] @@ -1725,17 +1714,6 @@ dependencies = [ "syn", ] -[[package]] -name = "zerocopy-derive" -version = "0.8.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7988d73a4303ca289df03316bc490e934accf371af6bc745393cf3c2c5c4f25d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 041d83731..ebe95a4cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,12 +50,6 @@ tokio-stream = { version = "0.1.1", optional = true } tracing = { version = "0.1.40", optional = true } tracing-subscriber = { version = "0.3.18", optional = true, features = ["env-filter"] } -# 'zerocopy' provides simple derives for converting types to and from byte -# representations, along with network-endian integer primitives. These are -# used to define simple elements of DNS messages and their serialization. -zerocopy = "0.8.5" -zerocopy-derive = "0.8.5" - [features] default = ["std", "rand"] diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 263db27af..7060a61eb 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -22,7 +22,7 @@ version = "1.0" [dependencies.syn] version = "2.0" -features = ["visit"] +features = ["full", "visit"] [dependencies.quote] version = "1.0" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 99d209fff..3fa1bc18b 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -47,7 +47,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { ); skeleton.lifetimes.push(param); skeleton.bound = Some( - parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); // Inspect the 'struct' fields. @@ -58,7 +58,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); } @@ -70,10 +70,10 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< (Self, & #lifetime [::domain::__core::primitive::u8]), - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { #(let (#init_vars, bytes) = - <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> + <#tys as ::domain::new_base::wire::SplitBytes<#lifetime>> ::split_bytes(bytes)?;)* Ok((#builder, bytes)) } @@ -119,7 +119,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { ); skeleton.lifetimes.push(param); skeleton.bound = Some( - parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), ); // Inspect the 'struct' fields. @@ -130,13 +130,13 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); } if let Some(field) = data.unsized_field() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), ); } @@ -147,12 +147,12 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< Self, - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { if bytes.is_empty() { Ok(#builder) } else { - Err(::domain::new_base::parse::ParseError) + Err(::domain::new_base::wire::ParseError) } } }); @@ -170,13 +170,13 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< Self, - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { #(let (#init_vars, bytes) = - <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> + <#tys as ::domain::new_base::wire::SplitBytes<#lifetime>> ::split_bytes(bytes)?;)* let #unsized_init_var = - <#unsized_ty as ::domain::new_base::parse::ParseBytes<#lifetime>> + <#unsized_ty as ::domain::new_base::wire::ParseBytes<#lifetime>> ::parse_bytes(bytes)?; Ok(#builder) } @@ -217,7 +217,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = - Some(parse_quote!(::domain::new_base::parse::SplitBytesByRef)); + Some(parse_quote!(::domain::new_base::wire::SplitBytesByRef)); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -226,7 +226,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::SplitBytesByRef), + parse_quote!(::domain::new_base::wire::SplitBytesByRef), ); } @@ -237,7 +237,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< (&Self, &[::domain::__core::primitive::u8]), - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { Ok(( // SAFETY: 'Self' is a 'struct' with no fields, @@ -260,17 +260,17 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< (&Self, &[::domain::__core::primitive::u8]), - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { let start = bytes.as_ptr(); #(let (_, bytes) = - <#tys as ::domain::new_base::parse::SplitBytesByRef> + <#tys as ::domain::new_base::wire::SplitBytesByRef> ::split_bytes_by_ref(bytes)?;)* let (last, rest) = - <#unsized_ty as ::domain::new_base::parse::SplitBytesByRef> + <#unsized_ty as ::domain::new_base::wire::SplitBytesByRef> ::split_bytes_by_ref(bytes)?; let ptr = - <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::ptr_with_address(last, start as *const ()); // SAFETY: @@ -287,6 +287,40 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); + // Define 'split_bytes_by_mut()'. + let tys = data.sized_fields().map(|f| &f.ty); + skeleton.contents.stmts.push(parse_quote! { + fn split_bytes_by_mut( + bytes: &mut [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&mut Self, &mut [::domain::__core::primitive::u8]), + ::domain::new_base::wire::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_mut(bytes)?;)* + let (last, rest) = + <#unsized_ty as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_mut(bytes)?; + let ptr = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. + Ok((unsafe { &mut *(ptr as *const Self as *mut Self) }, rest)) + } + }); + Ok(skeleton.into_token_stream()) } @@ -322,7 +356,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = - Some(parse_quote!(::domain::new_base::parse::ParseBytesByRef)); + Some(parse_quote!(::domain::new_base::wire::ParseBytesByRef)); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -331,13 +365,13 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::SplitBytesByRef), + parse_quote!(::domain::new_base::wire::SplitBytesByRef), ); } if let Some(field) = data.unsized_field() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::ParseBytesByRef), + parse_quote!(::domain::new_base::wire::ParseBytesByRef), ); } @@ -348,7 +382,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< &Self, - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { if bytes.is_empty() { // SAFETY: 'Self' is a 'struct' with no fields, @@ -356,7 +390,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // constructed at any address. Ok(unsafe { &*bytes.as_ptr().cast::() }) } else { - Err(::domain::new_base::parse::ParseError) + Err(::domain::new_base::wire::ParseError) } } }); @@ -381,17 +415,17 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< &Self, - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { let start = bytes.as_ptr(); #(let (_, bytes) = - <#tys as ::domain::new_base::parse::SplitBytesByRef> + <#tys as ::domain::new_base::wire::SplitBytesByRef> ::split_bytes_by_ref(bytes)?;)* let last = - <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::parse_bytes_by_ref(bytes)?; let ptr = - <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::ptr_with_address(last, start as *const ()); // SAFETY: @@ -408,11 +442,45 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); + // Define 'parse_bytes_by_mut()'. + let tys = data.sized_fields().map(|f| &f.ty); + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes_by_mut( + bytes: &mut [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &mut Self, + ::domain::new_base::wire::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_mut(bytes)?;)* + let last = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::parse_bytes_by_mut(bytes)?; + let ptr = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. + Ok(unsafe { &mut *(ptr as *const Self as *mut Self) }) + } + }); + // Define 'ptr_with_address()'. let unsized_member = data.unsized_member(); skeleton.contents.stmts.push(parse_quote! { fn ptr_with_address(&self, addr: *const ()) -> *const Self { - <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::ptr_with_address(&self.#unsized_member, addr) as *const Self } @@ -451,7 +519,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, false); skeleton.bound = - Some(parse_quote!(::domain::new_base::build::BuildBytes)); + Some(parse_quote!(::domain::new_base::wire::BuildBytes)); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -463,7 +531,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::build::BuildBytes), + parse_quote!(::domain::new_base::wire::BuildBytes), ); } @@ -476,9 +544,9 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { mut bytes: & #lifetime mut [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< & #lifetime mut [::domain::__core::primitive::u8], - ::domain::new_base::build::TruncationError, + ::domain::new_base::wire::TruncationError, > { - #(bytes = <#tys as ::domain::new_base::build::BuildBytes> + #(bytes = <#tys as ::domain::new_base::wire::BuildBytes> ::build_bytes(&self.#members, bytes)?;)* Ok(bytes) } @@ -519,13 +587,13 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = - Some(parse_quote!(::domain::new_base::build::AsBytes)); + Some(parse_quote!(::domain::new_base::wire::AsBytes)); // Establish bounds on the fields. for field in data.fields.iter() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::build::AsBytes), + parse_quote!(::domain::new_base::wire::AsBytes), ); } diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 9245b9011..e02da91db 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -6,11 +6,11 @@ use core::{ ptr::{self, NonNull}, }; -use zerocopy::{FromBytes, IntoBytes, SizeError}; - -use crate::new_base::{name::RevName, Header, Message}; - -use super::{BuildBytes, TruncationError}; +use crate::new_base::{ + name::RevName, + wire::{AsBytes, BuildBytes, ParseBytesByRef, TruncationError}, + Header, Message, +}; //----------- Builder -------------------------------------------------------- @@ -82,8 +82,7 @@ impl<'b> Builder<'b> { context: &'b mut BuilderContext, ) -> Self { assert!(buffer.len() >= 12); - let message = Message::mut_from_bytes(buffer) - .map_err(SizeError::from) + let message = Message::parse_bytes_by_mut(buffer) .expect("A 'Message' can fit in 12 bytes"); context.size = 0; context.max_size = message.contents.len(); @@ -156,9 +155,8 @@ impl<'b> Builder<'b> { pub fn message(&self) -> &Message { // SAFETY: All of 'message' can be immutably borrowed by 'self'. let message = unsafe { &*self.message.as_ptr() }; - let message = message.as_bytes(); - Message::ref_from_bytes_with_elems(message, self.commit) - .map_err(SizeError::from) + let message = &message.as_bytes()[..12 + self.commit]; + Message::parse_bytes_by_ref(message) .expect("'message' represents a valid 'Message'") } @@ -170,9 +168,8 @@ impl<'b> Builder<'b> { pub fn cur_message(&self) -> &Message { // SAFETY: All of 'message' can be immutably borrowed by 'self'. let message = unsafe { &*self.message.as_ptr() }; - let message = message.as_bytes(); - Message::ref_from_bytes_with_elems(message, self.context.size) - .map_err(SizeError::from) + let message = &message.as_bytes()[..12 + self.context.size]; + Message::parse_bytes_by_ref(message) .expect("'message' represents a valid 'Message'") } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 548b2d8fd..2faca3c16 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -1,12 +1,10 @@ //! Building DNS messages in the wire format. -use core::fmt; - -use zerocopy::network_endian::{U16, U32}; - mod builder; pub use builder::{Builder, BuilderContext}; +pub use super::wire::TruncationError; + //----------- Message-aware building traits ---------------------------------- /// Building into a DNS message. @@ -41,150 +39,3 @@ impl BuildIntoMessage for [u8] { Ok(()) } } - -//----------- Low-level building traits -------------------------------------- - -/// Serializing into a byte string. -pub trait BuildBytes { - /// Serialize into a byte string. - /// - /// `self` is serialized into a byte string and written to the given - /// buffer. If the buffer is large enough, the whole object is written - /// and the remaining (unmodified) part of the buffer is returned. - /// - /// if the buffer is too small, a [`TruncationError`] is returned (and - /// parts of the buffer may be modified). - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError>; -} - -impl BuildBytes for &T { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - T::build_bytes(*self, bytes) - } -} - -impl BuildBytes for u8 { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - if let Some((elem, rest)) = bytes.split_first_mut() { - *elem = *self; - Ok(rest) - } else { - Err(TruncationError) - } - } -} - -impl BuildBytes for str { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_bytes(bytes) - } -} - -impl BuildBytes for U16 { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_bytes(bytes) - } -} - -impl BuildBytes for U32 { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_bytes(bytes) - } -} - -impl BuildBytes for [T] { - fn build_bytes<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - for elem in self { - bytes = elem.build_bytes(bytes)?; - } - Ok(bytes) - } -} - -impl BuildBytes for [T; N] { - fn build_bytes<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - for elem in self { - bytes = elem.build_bytes(bytes)?; - } - Ok(bytes) - } -} - -/// Interpreting a value as a byte string. -/// -/// # Safety -/// -/// A type `T` can soundly implement [`AsBytes`] if and only if: -/// -/// - It has no padding bytes. -/// - It has no interior mutability. -pub unsafe trait AsBytes { - /// Interpret this value as a sequence of bytes. - /// - /// ## Invariants - /// - /// For the statement `let bytes = this.as_bytes();`, - /// - /// - `bytes.as_ptr() as usize == this as *const _ as usize`. - /// - `bytes.len() == core::mem::size_of_val(this)`. - /// - /// The default implementation automatically satisfies these invariants. - fn as_bytes(&self) -> &[u8] { - // SAFETY: - // - 'Self' has no padding bytes and no interior mutability. - // - Its size in memory is exactly 'size_of_val(self)'. - unsafe { - core::slice::from_raw_parts( - self as *const Self as *const u8, - core::mem::size_of_val(self), - ) - } - } -} - -unsafe impl AsBytes for u8 {} -unsafe impl AsBytes for str {} - -unsafe impl AsBytes for [T] {} -unsafe impl AsBytes for [T; N] {} - -unsafe impl AsBytes for U16 {} -unsafe impl AsBytes for U32 {} - -//----------- TruncationError ------------------------------------------------ - -/// A DNS message did not fit in a buffer. -#[derive(Clone, Debug, PartialEq, Hash)] -pub struct TruncationError; - -//--- Formatting - -impl fmt::Display for TruncationError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("A buffer was too small to fit a DNS message") - } -} diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 2a82e95fa..979e8be20 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -2,13 +2,12 @@ use core::{fmt, ops::Range}; -use zerocopy::IntoBytes; - use super::{ - build::{self, BuildBytes, BuildIntoMessage, TruncationError}, - parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, + build::{self, BuildIntoMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, + TruncationError, }, Message, }; diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 3307609bb..9c27d384f 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -2,17 +2,14 @@ use core::fmt; -use zerocopy::network_endian::U16; -use zerocopy_derive::*; - use domain_macros::{AsBytes, *}; +use super::wire::U16; + //----------- Message -------------------------------------------------------- /// A DNS message. -#[derive( - FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, ParseBytesByRef, -)] +#[derive(AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C, packed)] pub struct Message { /// The message header. @@ -30,11 +27,6 @@ pub struct Message { Clone, Debug, Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, AsBytes, BuildBytes, ParseBytes, @@ -76,11 +68,6 @@ impl fmt::Display for Header { Clone, Default, Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, AsBytes, BuildBytes, ParseBytes, @@ -243,11 +230,6 @@ impl fmt::Display for HeaderFlags { PartialEq, Eq, Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, AsBytes, BuildBytes, ParseBytes, diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 3c2e34068..df8632884 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -32,3 +32,4 @@ pub use serial::Serial; pub mod build; pub mod parse; +pub mod wire; diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 78ef94008..9cb4d1d85 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -9,7 +9,7 @@ use core::{ use domain_macros::AsBytes; -use crate::new_base::parse::{ParseBytes, ParseError, SplitBytes}; +use crate::new_base::wire::{ParseBytes, ParseError, SplitBytes}; //----------- Label ---------------------------------------------------------- @@ -56,8 +56,8 @@ impl<'a> SplitBytes<'a> for &'a Label { fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (&size, rest) = bytes.split_first().ok_or(ParseError)?; if size < 64 && rest.len() >= size as usize { - let (label, rest) = bytes.split_at(1 + size as usize); - // SAFETY: 'label' begins with a valid length octet. + let (label, rest) = rest.split_at(size as usize); + // SAFETY: 'label' is 'size < 64' bytes in size. Ok((unsafe { Label::from_bytes_unchecked(label) }, rest)) } else { Err(ParseError) diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 6fae3c0f2..ba7cdb8c6 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -8,13 +8,12 @@ use core::{ ops::{Deref, Range}, }; -use zerocopy::IntoBytes; - use crate::new_base::{ - build::{self, BuildBytes, BuildIntoMessage, TruncationError}, - parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, + build::{self, BuildIntoMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, + TruncationError, }, Message, }; @@ -228,7 +227,7 @@ impl RevNameBuf { /// Construct an empty, invalid buffer. fn empty() -> Self { Self { - offset: 0, + offset: 255, buffer: [0; 255], } } diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 32bc3627a..d36dd9543 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,11 +1,6 @@ //! Parsing DNS messages from the wire format. -use core::{fmt, ops::Range}; - -use zerocopy::{ - network_endian::{U16, U32}, - FromBytes, IntoBytes, -}; +use core::ops::Range; mod message; pub use message::{MessagePart, ParseMessage, VisitMessagePart}; @@ -16,7 +11,12 @@ pub use question::{ParseQuestion, ParseQuestions, VisitQuestion}; mod record; pub use record::{ParseRecord, ParseRecords, VisitRecord}; -use super::Message; +pub use super::wire::ParseError; + +use super::{ + wire::{AsBytes, ParseBytesByRef, SplitBytesByRef}, + Message, +}; //----------- Message-aware parsing traits ----------------------------------- @@ -68,335 +68,3 @@ impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { T::parse_bytes_by_ref(bytes) } } - -//----------- Low-level parsing traits --------------------------------------- - -/// Parsing from the start of a byte string. -pub trait SplitBytes<'a>: Sized + ParseBytes<'a> { - /// Parse a value of [`Self`] from the start of the byte string. - /// - /// If parsing is successful, the parsed value and the rest of the string - /// are returned. Otherwise, a [`ParseError`] is returned. - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; -} - -/// Parsing from a byte string. -pub trait ParseBytes<'a>: Sized { - /// Parse a value of [`Self`] from the given byte string. - /// - /// If parsing is successful, the parsed value is returned. Otherwise, a - /// [`ParseError`] is returned. - fn parse_bytes(bytes: &'a [u8]) -> Result; -} - -impl<'a, T: ?Sized + SplitBytesByRef> SplitBytes<'a> for &'a T { - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - T::split_bytes_by_ref(bytes).map_err(|_| ParseError) - } -} - -impl<'a, T: ?Sized + ParseBytesByRef> ParseBytes<'a> for &'a T { - fn parse_bytes(bytes: &'a [u8]) -> Result { - T::parse_bytes_by_ref(bytes).map_err(|_| ParseError) - } -} - -impl<'a> SplitBytes<'a> for u8 { - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - bytes.split_first().map(|(&f, r)| (f, r)).ok_or(ParseError) - } -} - -impl<'a> ParseBytes<'a> for u8 { - fn parse_bytes(bytes: &'a [u8]) -> Result { - let [result] = bytes else { - return Err(ParseError); - }; - - Ok(*result) - } -} - -impl<'a> SplitBytes<'a> for U16 { - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - Self::read_from_prefix(bytes).map_err(Into::into) - } -} - -impl<'a> ParseBytes<'a> for U16 { - fn parse_bytes(bytes: &'a [u8]) -> Result { - Self::read_from_bytes(bytes).map_err(Into::into) - } -} - -impl<'a> SplitBytes<'a> for U32 { - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - Self::read_from_prefix(bytes).map_err(Into::into) - } -} - -impl<'a> ParseBytes<'a> for U32 { - fn parse_bytes(bytes: &'a [u8]) -> Result { - Self::read_from_bytes(bytes).map_err(Into::into) - } -} - -/// Zero-copy parsing from the start of a byte string. -/// -/// This is an extension of [`ParseBytesByRef`] for types which can determine -/// their own length when parsing. It is usually implemented by [`Sized`] -/// types (where the length is just the size of the type), although it can be -/// sometimes implemented by unsized types. -/// -/// # Safety -/// -/// Every implementation of [`SplitBytesByRef`] must satisfy the invariants -/// documented on [`split_bytes_by_ref()`]. An incorrect implementation is -/// considered to cause undefined behaviour. -/// -/// [`split_bytes_by_ref()`]: Self::split_bytes_by_ref() -/// -/// Note that [`ParseBytesByRef`], required by this trait, also has several -/// invariants that need to be considered with care. -pub unsafe trait SplitBytesByRef: ParseBytesByRef { - /// Interpret a byte string as an instance of [`Self`]. - /// - /// The byte string will be validated and re-interpreted as a reference to - /// [`Self`]. The length of [`Self`] will be determined, possibly based - /// on the contents (but not the length!) of the input, and the remaining - /// bytes will be returned. If the input does not begin with a valid - /// instance of [`Self`], a [`ParseError`] is returned. - /// - /// ## Invariants - /// - /// For the statement `let (this, rest) = T::split_bytes_by_ref(bytes)?;`, - /// - /// - `bytes.as_ptr() == this as *const T as *const u8`. - /// - `bytes.len() == core::mem::size_of_val(this) + rest.len()`. - /// - `bytes.as_ptr().offset(size_of_val(this)) == rest.as_ptr()`. - fn split_bytes_by_ref(bytes: &[u8]) - -> Result<(&Self, &[u8]), ParseError>; -} - -/// Zero-copy parsing from a byte string. -/// -/// # Safety -/// -/// Every implementation of [`ParseBytesByRef`] must satisfy the invariants -/// documented on [`parse_bytes_by_ref()`] and [`ptr_with_address()`]. An -/// incorrect implementation is considered to cause undefined behaviour. -/// -/// [`parse_bytes_by_ref()`]: Self::parse_bytes_by_ref() -/// [`ptr_with_address()`]: Self::ptr_with_address() -/// -/// Implementing types must also have no alignment (i.e. a valid instance of -/// [`Self`] can occur at any address). This eliminates the possibility of -/// padding bytes, even when [`Self`] is part of a larger aggregate type. -pub unsafe trait ParseBytesByRef { - /// Interpret a byte string as an instance of [`Self`]. - /// - /// The byte string will be validated and re-interpreted as a reference to - /// [`Self`]. The whole byte string will be used. If the input is not a - /// valid instance of [`Self`], a [`ParseError`] is returned. - /// - /// ## Invariants - /// - /// For the statement `let this: &T = T::parse_bytes_by_ref(bytes)?;`, - /// - /// - `bytes.as_ptr() == this as *const T as *const u8`. - /// - `bytes.len() == core::mem::size_of_val(this)`. - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError>; - - /// Change the address of a pointer to [`Self`]. - /// - /// When [`Self`] is used as the last field in a type that also implements - /// [`ParseBytesByRef`], it may be dynamically sized, and so a pointer (or - /// reference) to it may include additional metadata. This metadata is - /// included verbatim in any reference/pointer to the containing type. - /// - /// When the containing type implements [`ParseBytesByRef`], it needs to - /// construct a reference/pointer to itself, which includes this metadata. - /// Rust does not currently offer a general way to extract this metadata - /// or pair it with another address, so this function is necessary. The - /// caller can construct a reference to [`Self`], then change its address - /// to point to the containing type, then cast that pointer to the right - /// type. - /// - /// # Implementing - /// - /// Most users will derive [`ParseBytesByRef`] and so don't need to worry - /// about this. For manual implementations: - /// - /// In the future, an adequate default implementation for this function - /// may be provided. Until then, it should be implemented using one of - /// the following expressions: - /// - /// ```ignore - /// fn ptr_with_address( - /// &self, - /// addr: *const (), - /// ) -> *const Self { - /// // If 'Self' is Sized: - /// addr.cast::() - /// - /// // If 'Self' is an aggregate whose last field is 'last': - /// self.last.ptr_with_address(addr) as *const Self - /// } - /// ``` - /// - /// # Invariants - /// - /// For the statement `let result = Self::ptr_with_address(ptr, addr);`: - /// - /// - `result as usize == addr as usize`. - /// - `core::ptr::metadata(result) == core::ptr::metadata(ptr)`. - fn ptr_with_address(&self, addr: *const ()) -> *const Self; -} - -unsafe impl SplitBytesByRef for u8 { - fn split_bytes_by_ref( - bytes: &[u8], - ) -> Result<(&Self, &[u8]), ParseError> { - bytes.split_first().ok_or(ParseError) - } -} - -unsafe impl ParseBytesByRef for u8 { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - let [result] = bytes else { - return Err(ParseError); - }; - - Ok(result) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - addr.cast() - } -} - -unsafe impl SplitBytesByRef for U16 { - fn split_bytes_by_ref( - bytes: &[u8], - ) -> Result<(&Self, &[u8]), ParseError> { - Self::ref_from_prefix(bytes).map_err(Into::into) - } -} - -unsafe impl ParseBytesByRef for U16 { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - Self::ref_from_bytes(bytes).map_err(Into::into) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - addr.cast() - } -} - -unsafe impl SplitBytesByRef for U32 { - fn split_bytes_by_ref( - bytes: &[u8], - ) -> Result<(&Self, &[u8]), ParseError> { - Self::ref_from_prefix(bytes).map_err(Into::into) - } -} - -unsafe impl ParseBytesByRef for U32 { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - Self::ref_from_bytes(bytes).map_err(Into::into) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - addr.cast() - } -} - -unsafe impl ParseBytesByRef for [u8] { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - Ok(bytes) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - core::ptr::slice_from_raw_parts(addr.cast(), self.len()) - } -} - -unsafe impl ParseBytesByRef for str { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - core::str::from_utf8(bytes).map_err(|_| ParseError) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - // NOTE: The Rust Reference indicates that 'str' has the same layout - // as '[u8]' [1]. This is also the most natural layout for it. Since - // there's no way to construct a '*const str' from raw parts, we will - // just construct a raw slice and transmute it. - // - // [1]: https://doc.rust-lang.org/reference/type-layout.html#str-layout - - self.as_bytes().ptr_with_address(addr) as *const Self - } -} - -unsafe impl SplitBytesByRef for [T; N] { - fn split_bytes_by_ref( - mut bytes: &[u8], - ) -> Result<(&Self, &[u8]), ParseError> { - let start = bytes.as_ptr(); - for _ in 0..N { - (_, bytes) = T::split_bytes_by_ref(bytes)?; - } - - // SAFETY: - // - 'T::split_bytes_by_ref()' was called 'N' times on successive - // positions, thus the original 'bytes' starts with 'N' instances - // of 'T' (even if 'T' is a ZST and so all instances overlap). - // - 'N' consecutive 'T's have the same layout as '[T; N]'. - // - Thus it is safe to cast 'start' to '[T; N]'. - // - The referenced data has the same lifetime as the output. - Ok((unsafe { &*start.cast::<[T; N]>() }, bytes)) - } -} - -unsafe impl ParseBytesByRef for [T; N] { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - let (this, rest) = Self::split_bytes_by_ref(bytes)?; - if rest.is_empty() { - Ok(this) - } else { - Err(ParseError) - } - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - addr.cast() - } -} - -//----------- ParseError ----------------------------------------------------- - -/// A DNS message parsing error. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct ParseError; - -//--- Formatting - -impl fmt::Display for ParseError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("DNS data could not be parsed from the wire format") - } -} - -//--- Conversion from 'zerocopy' errors - -impl From> for ParseError { - fn from(_: zerocopy::ConvertError) -> Self { - Self - } -} - -impl From> for ParseError { - fn from(_: zerocopy::SizeError) -> Self { - Self - } -} diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 4e93951aa..0dad0910a 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,14 +2,13 @@ use core::ops::Range; -use zerocopy::network_endian::U16; - use domain_macros::*; use super::{ - build::{self, AsBytes, BuildIntoMessage, TruncationError}, + build::{self, BuildIntoMessage}, name::RevNameBuf, - parse::{ParseError, ParseFromMessage, SplitFromMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{AsBytes, ParseError, TruncationError, U16}, Message, }; diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 391b95dee..2d84e0934 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -5,19 +5,13 @@ use core::{ ops::{Deref, Range}, }; -use zerocopy::{ - network_endian::{U16, U32}, - FromBytes, IntoBytes, -}; - -use domain_macros::*; - use super::{ - build::{self, AsBytes, BuildBytes, BuildIntoMessage, TruncationError}, + build::{self, BuildIntoMessage}, name::RevNameBuf, - parse::{ - ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, - SplitBytes, SplitFromMessage, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, + SplitBytes, SplitBytesByRef, TruncationError, U16, U32, }, Message, }; @@ -160,9 +154,13 @@ where let (rtype, rest) = RType::split_bytes(rest)?; let (rclass, rest) = RClass::split_bytes(rest)?; let (ttl, rest) = TTL::split_bytes(rest)?; - let (size, rest) = U16::read_from_prefix(rest)?; + let (size, rest) = U16::split_bytes(rest)?; let size: usize = size.get().into(); - let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + if rest.len() < size { + return Err(ParseError); + } + + let (rdata, rest) = rest.split_at(size); let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) @@ -179,10 +177,13 @@ where let (rtype, rest) = RType::split_bytes(rest)?; let (rclass, rest) = RClass::split_bytes(rest)?; let (ttl, rest) = TTL::split_bytes(rest)?; - let (size, rest) = U16::read_from_prefix(rest)?; + let (size, rest) = U16::split_bytes(rest)?; let size: usize = size.get().into(); - let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; - let rdata = D::parse_record_data_bytes(rdata, rtype)?; + if rest.len() != size { + return Err(ParseError); + } + + let rdata = D::parse_record_data_bytes(rest, rtype)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -205,7 +206,7 @@ where bytes = self.ttl.as_bytes().build_bytes(bytes)?; let (size, bytes) = - ::mut_from_prefix(bytes).map_err(|_| TruncationError)?; + U16::split_bytes_by_mut(bytes).map_err(|_| TruncationError)?; let bytes_len = bytes.len(); let rest = self.rdata.build_bytes(bytes)?; diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index 4258c4b22..af0e4a1a1 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -8,10 +8,10 @@ use core::{ ops::{Add, AddAssign}, }; -use zerocopy::network_endian::U32; - use domain_macros::*; +use super::wire::U32; + //----------- Serial --------------------------------------------------------- /// A serial number. diff --git a/src/new_base/wire/build.rs b/src/new_base/wire/build.rs new file mode 100644 index 000000000..1b67a4d40 --- /dev/null +++ b/src/new_base/wire/build.rs @@ -0,0 +1,192 @@ +//! Building data in the basic network format. + +use core::fmt; + +//----------- BuildBytes ----------------------------------------------------- + +/// Serializing into a byte string. +pub trait BuildBytes { + /// Serialize into a byte string. + /// + /// `self` is serialized into a byte string and written to the given + /// buffer. If the buffer is large enough, the whole object is written + /// and the remaining (unmodified) part of the buffer is returned. + /// + /// if the buffer is too small, a [`TruncationError`] is returned (and + /// parts of the buffer may be modified). + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError>; +} + +impl BuildBytes for &T { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + T::build_bytes(*self, bytes) + } +} + +impl BuildBytes for u8 { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + if let Some((elem, rest)) = bytes.split_first_mut() { + *elem = *self; + Ok(rest) + } else { + Err(TruncationError) + } + } +} + +impl BuildBytes for str { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_bytes(bytes) + } +} + +impl BuildBytes for [T] { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + for elem in self { + bytes = elem.build_bytes(bytes)?; + } + Ok(bytes) + } +} + +impl BuildBytes for [T; N] { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + for elem in self { + bytes = elem.build_bytes(bytes)?; + } + Ok(bytes) + } +} + +/// Deriving [`BuildBytes`] automatically. +/// +/// [`BuildBytes`] can be derived on `struct`s (not `enum`s or `union`s). The +/// generated implementation will call [`build_bytes()`] with each field, in +/// the order they are declared. The trait implementation will be bounded by +/// the type of every field implementing [`BuildBytes`]. +/// +/// [`build_bytes()`]: BuildBytes::build_bytes() +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{BuildBytes, U32, TruncationError}; +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(BuildBytes)': +/// impl BuildBytes for Foo +/// where Bar: BuildBytes { +/// fn build_bytes<'bytes>( +/// &self, +/// mut bytes: &'bytes mut [u8], +/// ) -> Result<&'bytes mut [u8], TruncationError> { +/// bytes = self.a.build_bytes(bytes)?; +/// bytes = self.b.build_bytes(bytes)?; +/// Ok(bytes) +/// } +/// } +/// ``` +pub use domain_macros::BuildBytes; + +//----------- AsBytes -------------------------------------------------------- + +/// Interpreting a value as a byte string. +/// +/// # Safety +/// +/// A type `T` can soundly implement [`AsBytes`] if and only if: +/// +/// - It has no padding bytes. +/// - It has no interior mutability. +pub unsafe trait AsBytes { + /// Interpret this value as a sequence of bytes. + /// + /// ## Invariants + /// + /// For the statement `let bytes = this.as_bytes();`, + /// + /// - `bytes.as_ptr() as usize == this as *const _ as usize`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + /// + /// The default implementation automatically satisfies these invariants. + fn as_bytes(&self) -> &[u8] { + // SAFETY: + // - 'Self' has no padding bytes and no interior mutability. + // - Its size in memory is exactly 'size_of_val(self)'. + unsafe { + core::slice::from_raw_parts( + self as *const Self as *const u8, + core::mem::size_of_val(self), + ) + } + } +} + +unsafe impl AsBytes for u8 {} +unsafe impl AsBytes for str {} + +unsafe impl AsBytes for [T] {} +unsafe impl AsBytes for [T; N] {} + +/// Deriving [`AsBytes`] automatically. +/// +/// [`AsBytes`] can be derived on `struct`s (not `enum`s or `union`s), where a +/// fixed memory layout (`repr(C)` or `repr(transparent)`) is used. Every +/// field must implement [`AsBytes`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{AsBytes, U32}; +/// #[repr(C)] +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(AsBytes)': +/// unsafe impl AsBytes for Foo +/// where Bar: AsBytes { +/// // The default implementation of 'as_bytes()' works. +/// } +/// ``` +pub use domain_macros::AsBytes; + +//----------- TruncationError ------------------------------------------------ + +/// A DNS message did not fit in a buffer. +#[derive(Clone, Debug, PartialEq, Hash)] +pub struct TruncationError; + +//--- Formatting + +impl fmt::Display for TruncationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("A buffer was too small to fit a DNS message") + } +} diff --git a/src/new_base/wire/ints.rs b/src/new_base/wire/ints.rs new file mode 100644 index 000000000..3d11f45e4 --- /dev/null +++ b/src/new_base/wire/ints.rs @@ -0,0 +1,282 @@ +//! Integer primitives for the DNS wire format. + +use core::{ + cmp::Ordering, + fmt, + ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, + BitXorAssign, Not, Sub, SubAssign, + }, +}; + +use domain_macros::*; + +use super::{ + ParseBytes, ParseBytesByRef, ParseError, SplitBytes, SplitBytesByRef, +}; + +//----------- define_int ----------------------------------------------------- + +/// Define a network endianness integer primitive. +macro_rules! define_int { + { $( + $(#[$docs:meta])* + $name:ident([u8; $size:literal]) = $base:ident; + )* } => { $( + $(#[$docs])* + #[derive( + Copy, + Clone, + Default, + PartialEq, + Eq, + Hash, + AsBytes, + BuildBytes, + ParseBytesByRef, + SplitBytesByRef, + )] + #[repr(transparent)] + pub struct $name([u8; $size]); + + //--- Conversion to and from integer primitive types + + impl $name { + /// Convert an integer to network endianness. + pub const fn new(value: $base) -> Self { + Self(value.to_be_bytes()) + } + + /// Convert an integer from network endianness. + pub const fn get(self) -> $base { + <$base>::from_be_bytes(self.0) + } + } + + impl From<$base> for $name { + fn from(value: $base) -> Self { + Self::new(value) + } + } + + impl From<$name> for $base { + fn from(value: $name) -> Self { + value.get() + } + } + + //--- Parsing from bytes + + impl<'b> ParseBytes<'b> for $name { + fn parse_bytes(bytes: &'b [u8]) -> Result { + Self::parse_bytes_by_ref(bytes).copied() + } + } + + impl<'b> SplitBytes<'b> for $name { + fn split_bytes( + bytes: &'b [u8], + ) -> Result<(Self, &'b [u8]), ParseError> { + Self::split_bytes_by_ref(bytes) + .map(|(&this, rest)| (this, rest)) + } + } + + //--- Comparison + + impl PartialEq<$base> for $name { + fn eq(&self, other: &$base) -> bool { + self.get() == *other + } + } + + impl PartialOrd<$base> for $name { + fn partial_cmp(&self, other: &$base) -> Option { + self.get().partial_cmp(other) + } + } + + impl PartialOrd for $name { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for $name { + fn cmp(&self, other: &Self) -> Ordering { + self.get().cmp(&other.get()) + } + } + + //--- Formatting + + impl fmt::Debug for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple(stringify!($name)).field(&self.get()).finish() + } + } + + //--- Arithmetic + + impl Add for $name { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self::new(self.get() + rhs.get()) + } + } + + impl AddAssign for $name { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } + } + + impl Add<$base> for $name { + type Output = Self; + + fn add(self, rhs: $base) -> Self::Output { + Self::new(self.get() + rhs) + } + } + + impl AddAssign<$base> for $name { + fn add_assign(&mut self, rhs: $base) { + *self = *self + rhs; + } + } + + impl Sub for $name { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self::new(self.get() - rhs.get()) + } + } + + impl SubAssign for $name { + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } + } + + impl Sub<$base> for $name { + type Output = Self; + + fn sub(self, rhs: $base) -> Self::Output { + Self::new(self.get() - rhs) + } + } + + impl SubAssign<$base> for $name { + fn sub_assign(&mut self, rhs: $base) { + *self = *self - rhs; + } + } + + impl Not for $name { + type Output = Self; + + fn not(self) -> Self::Output { + Self::new(!self.get()) + } + } + + //--- Bitwise operations + + impl BitAnd for $name { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + Self::new(self.get() & rhs.get()) + } + } + + impl BitAndAssign for $name { + fn bitand_assign(&mut self, rhs: Self) { + *self = *self & rhs; + } + } + + impl BitAnd<$base> for $name { + type Output = Self; + + fn bitand(self, rhs: $base) -> Self::Output { + Self::new(self.get() & rhs) + } + } + + impl BitAndAssign<$base> for $name { + fn bitand_assign(&mut self, rhs: $base) { + *self = *self & rhs; + } + } + + impl BitOr for $name { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + Self::new(self.get() | rhs.get()) + } + } + + impl BitOrAssign for $name { + fn bitor_assign(&mut self, rhs: Self) { + *self = *self | rhs; + } + } + + impl BitOr<$base> for $name { + type Output = Self; + + fn bitor(self, rhs: $base) -> Self::Output { + Self::new(self.get() | rhs) + } + } + + impl BitOrAssign<$base> for $name { + fn bitor_assign(&mut self, rhs: $base) { + *self = *self | rhs; + } + } + + impl BitXor for $name { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + Self::new(self.get() ^ rhs.get()) + } + } + + impl BitXorAssign for $name { + fn bitxor_assign(&mut self, rhs: Self) { + *self = *self ^ rhs; + } + } + + impl BitXor<$base> for $name { + type Output = Self; + + fn bitxor(self, rhs: $base) -> Self::Output { + Self::new(self.get() ^ rhs) + } + } + + impl BitXorAssign<$base> for $name { + fn bitxor_assign(&mut self, rhs: $base) { + *self = *self ^ rhs; + } + } + )* }; +} + +define_int! { + /// An unsigned 16-bit integer in network endianness. + U16([u8; 2]) = u16; + + /// An unsigned 32-bit integer in network endianness. + U32([u8; 4]) = u32; + + /// An unsigned 64-bit integer in network endianness. + U64([u8; 8]) = u64; +} diff --git a/src/new_base/wire/mod.rs b/src/new_base/wire/mod.rs new file mode 100644 index 000000000..4d5be5c25 --- /dev/null +++ b/src/new_base/wire/mod.rs @@ -0,0 +1,81 @@ +//! The basic wire format of network protocols. +//! +//! This is a low-level module providing simple and efficient mechanisms to +//! parse data from and build data into byte sequences. It takes inspiration +//! from the [zerocopy] crate, but 1) is significantly simpler, 2) has simple +//! requirements for its `derive` macros, and 3) supports parsing out-of-place +//! (i.e. non-zero-copy). +//! +//! [zerocopy]: https://github.com/google/zerocopy +//! +//! # Design +//! +//! When a type is defined to represent a component of a network packet, its +//! internal structure should match the structure of its wire format. Here's +//! an example of a question in a DNS record: +//! +//! ``` +//! # use domain::new_base::{QType, QClass, wire::*}; +//! #[derive(BuildBytes, ParseBytes, SplitBytes)] +//! pub struct Question { +//! /// The domain name being requested. +//! pub qname: N, +//! +//! /// The type of the requested records. +//! pub qtype: QType, +//! +//! /// The class of the requested records. +//! pub qclass: QClass, +//! } +//! ``` +//! +//! This exactly matches the structure of a question on the wire -- the QNAME, +//! the QTYPE, and the QCLASS. This allows the definition of the type to also +//! specify the wire format concisely. +//! +//! Now, this type can be read from and written to bytes very easily: +//! +//! ``` +//! # use domain::new_base::{Question, name::RevNameBuf, wire::*}; +//! // { qname: "org.", qtype: A, qclass: IN } +//! let bytes = [3, 111, 114, 103, 0, 0, 1, 0, 1]; +//! let question = Question::::parse_bytes(&bytes).unwrap(); +//! let mut duplicate = [0u8; 9]; +//! let rest = question.build_bytes(&mut duplicate).unwrap(); +//! assert_eq!(*rest, []); +//! assert_eq!(bytes, duplicate); +//! ``` +//! +//! There are three important traits to consider: +//! +//! - [`ParseBytes`]: For interpreting an entire byte string as an instance of +//! the target type. +//! +//! - [`SplitBytes`]: For interpreting _the start_ of a byte string as an +//! instance of the target type. +//! +//! - [`BuildBytes`]: For serializing an object and writing it to the _start_ +//! of a byte string. +//! +//! These operate by value, and copy (some) data from the input. However, +//! there are also zero-copy versions of these traits, which are more +//! efficient (but not always applicable): +//! +//! - [`ParseBytesByRef`]: Like [`ParseBytes`], but transmutes the byte string +//! into an instance of the target type in place. +//! +//! - [`SplitBytesByRef`]: Like [`SplitBytes`], but transmutes the byte string +//! into an instance of the target type in place. +//! +//! - [`AsBytes`]: Allows interpreting an object as a byte string in place. + +mod build; +pub use build::{AsBytes, BuildBytes, TruncationError}; + +mod parse; +pub use parse::{ + ParseBytes, ParseBytesByRef, ParseError, SplitBytes, SplitBytesByRef, +}; + +mod ints; +pub use ints::{U16, U32, U64}; diff --git a/src/new_base/wire/parse.rs b/src/new_base/wire/parse.rs new file mode 100644 index 000000000..3ee5d44a1 --- /dev/null +++ b/src/new_base/wire/parse.rs @@ -0,0 +1,510 @@ +//! Parsing bytes in the basic network format. + +use core::fmt; + +//----------- ParseBytes ----------------------------------------------------- + +/// Parsing from a byte string. +pub trait ParseBytes<'a>: Sized { + /// Parse a value of [`Self`] from the given byte string. + /// + /// If parsing is successful, the parsed value is returned. Otherwise, a + /// [`ParseError`] is returned. + fn parse_bytes(bytes: &'a [u8]) -> Result; +} + +impl<'a> ParseBytes<'a> for u8 { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let [result] = bytes else { + return Err(ParseError); + }; + + Ok(*result) + } +} + +impl<'a, T: ?Sized + ParseBytesByRef> ParseBytes<'a> for &'a T { + fn parse_bytes(bytes: &'a [u8]) -> Result { + T::parse_bytes_by_ref(bytes).map_err(|_| ParseError) + } +} + +/// Deriving [`ParseBytes`] automatically. +/// +/// [`ParseBytes`] can be derived on `struct`s (not `enum`s or `union`s). All +/// fields except the last must implement [`SplitBytes`], while the last field +/// only needs to implement [`ParseBytes`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{ParseBytes, SplitBytes, U32, ParseError}; +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(ParseBytes)': +/// impl<'bytes, T> ParseBytes<'bytes> for Foo +/// where +/// U32: SplitBytes<'bytes>, +/// Bar: ParseBytes<'bytes>, +/// { +/// fn parse_bytes( +/// bytes: &'bytes [u8], +/// ) -> Result { +/// let (field_a, bytes) = U32::split_bytes(bytes)?; +/// let field_b = >::parse_bytes(bytes)?; +/// Ok(Self { a: field_a, b: field_b }) +/// } +/// } +/// ``` +pub use domain_macros::ParseBytes; + +//----------- SplitBytes ----------------------------------------------------- + +/// Parsing from the start of a byte string. +pub trait SplitBytes<'a>: Sized + ParseBytes<'a> { + /// Parse a value of [`Self`] from the start of the byte string. + /// + /// If parsing is successful, the parsed value and the rest of the string + /// are returned. Otherwise, a [`ParseError`] is returned. + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; +} + +impl<'a, T: ?Sized + SplitBytesByRef> SplitBytes<'a> for &'a T { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + T::split_bytes_by_ref(bytes).map_err(|_| ParseError) + } +} + +impl<'a> SplitBytes<'a> for u8 { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + bytes.split_first().map(|(&f, r)| (f, r)).ok_or(ParseError) + } +} + +/// Deriving [`SplitBytes`] automatically. +/// +/// [`SplitBytes`] can be derived on `struct`s (not `enum`s or `union`s). All +/// fields except the last must implement [`SplitBytes`], while the last field +/// only needs to implement [`SplitBytes`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{ParseBytes, SplitBytes, U32, ParseError}; +/// #[derive(ParseBytes)] +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(SplitBytes)': +/// impl<'bytes, T> SplitBytes<'bytes> for Foo +/// where +/// U32: SplitBytes<'bytes>, +/// Bar: SplitBytes<'bytes>, +/// { +/// fn split_bytes( +/// bytes: &'bytes [u8], +/// ) -> Result<(Self, &'bytes [u8]), ParseError> { +/// let (field_a, bytes) = U32::split_bytes(bytes)?; +/// let (field_b, bytes) = >::split_bytes(bytes)?; +/// Ok((Self { a: field_a, b: field_b }, bytes)) +/// } +/// } +/// ``` +pub use domain_macros::SplitBytes; + +//----------- ParseBytesByRef ------------------------------------------------ + +/// Zero-copy parsing from a byte string. +/// +/// # Safety +/// +/// Every implementation of [`ParseBytesByRef`] must satisfy the invariants +/// documented on [`parse_bytes_by_ref()`] and [`ptr_with_address()`]. An +/// incorrect implementation is considered to cause undefined behaviour. +/// +/// [`parse_bytes_by_ref()`]: Self::parse_bytes_by_ref() +/// [`ptr_with_address()`]: Self::ptr_with_address() +/// +/// Implementing types must also have no alignment (i.e. a valid instance of +/// [`Self`] can occur at any address). This eliminates the possibility of +/// padding bytes, even when [`Self`] is part of a larger aggregate type. +pub unsafe trait ParseBytesByRef { + /// Interpret a byte string as an instance of [`Self`]. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The whole byte string will be used. If the input is not a + /// valid instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let this: &T = T::parse_bytes_by_ref(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError>; + + /// Interpret a byte string as an instance of [`Self`], mutably. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The whole byte string will be used. If the input is not a + /// valid instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let this: &mut T = T::parse_bytes_by_mut(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError>; + + /// Change the address of a pointer to [`Self`]. + /// + /// When [`Self`] is used as the last field in a type that also implements + /// [`ParseBytesByRef`], it may be dynamically sized, and so a pointer (or + /// reference) to it may include additional metadata. This metadata is + /// included verbatim in any reference/pointer to the containing type. + /// + /// When the containing type implements [`ParseBytesByRef`], it needs to + /// construct a reference/pointer to itself, which includes this metadata. + /// Rust does not currently offer a general way to extract this metadata + /// or pair it with another address, so this function is necessary. The + /// caller can construct a reference to [`Self`], then change its address + /// to point to the containing type, then cast that pointer to the right + /// type. + /// + /// # Implementing + /// + /// Most users will derive [`ParseBytesByRef`] and so don't need to worry + /// about this. For manual implementations: + /// + /// In the future, an adequate default implementation for this function + /// may be provided. Until then, it should be implemented using one of + /// the following expressions: + /// + /// ```ignore + /// fn ptr_with_address( + /// &self, + /// addr: *const (), + /// ) -> *const Self { + /// // If 'Self' is Sized: + /// addr.cast::() + /// + /// // If 'Self' is an aggregate whose last field is 'last': + /// self.last.ptr_with_address(addr) as *const Self + /// } + /// ``` + /// + /// # Invariants + /// + /// For the statement `let result = Self::ptr_with_address(ptr, addr);`: + /// + /// - `result as usize == addr as usize`. + /// - `core::ptr::metadata(result) == core::ptr::metadata(ptr)`. + fn ptr_with_address(&self, addr: *const ()) -> *const Self; +} + +unsafe impl ParseBytesByRef for u8 { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + if let [result] = bytes { + Ok(result) + } else { + Err(ParseError) + } + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + if let [result] = bytes { + Ok(result) + } else { + Err(ParseError) + } + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + +unsafe impl ParseBytesByRef for [u8] { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Ok(bytes) + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + Ok(bytes) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + core::ptr::slice_from_raw_parts(addr.cast(), self.len()) + } +} + +unsafe impl ParseBytesByRef for str { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + core::str::from_utf8(bytes).map_err(|_| ParseError) + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + core::str::from_utf8_mut(bytes).map_err(|_| ParseError) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + // NOTE: The Rust Reference indicates that 'str' has the same layout + // as '[u8]' [1]. This is also the most natural layout for it. Since + // there's no way to construct a '*const str' from raw parts, we will + // just construct a raw slice and transmute it. + // + // [1]: https://doc.rust-lang.org/reference/type-layout.html#str-layout + + self.as_bytes().ptr_with_address(addr) as *const Self + } +} + +unsafe impl ParseBytesByRef for [T; N] { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + let (this, rest) = Self::split_bytes_by_ref(bytes)?; + if rest.is_empty() { + Ok(this) + } else { + Err(ParseError) + } + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + let (this, rest) = Self::split_bytes_by_mut(bytes)?; + if rest.is_empty() { + Ok(this) + } else { + Err(ParseError) + } + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + +/// Deriving [`ParseBytesByRef`] automatically. +/// +/// [`ParseBytesByRef`] can be derived on `struct`s (not `enum`s or `union`s), +/// where a fixed memory layout (`repr(C)` or `repr(transparent)`) is used. +/// All fields except the last must implement [`SplitBytesByRef`], while the +/// last field only needs to implement [`ParseBytesByRef`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{ParseBytesByRef, SplitBytesByRef, U32, ParseError}; +/// #[repr(C)] +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(ParseBytesByRef)': +/// unsafe impl ParseBytesByRef for Foo +/// where Bar: ParseBytesByRef { +/// fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { +/// let addr = bytes.as_ptr(); +/// let (_, bytes) = U32::split_bytes_by_ref(bytes)?; +/// let last = >::parse_bytes_by_ref(bytes)?; +/// let this = last.ptr_with_address(addr as *const ()); +/// Ok(unsafe { &*(this as *const Self) }) +/// } +/// +/// fn parse_bytes_by_mut( +/// bytes: &mut [u8], +/// ) -> Result<&mut Self, ParseError> { +/// let addr = bytes.as_ptr(); +/// let (_, bytes) = U32::split_bytes_by_ref(bytes)?; +/// let last = >::parse_bytes_by_ref(bytes)?; +/// let this = last.ptr_with_address(addr as *const ()); +/// Ok(unsafe { &mut *(this as *const Self as *mut Self) }) +/// } +/// +/// fn ptr_with_address(&self, addr: *const ()) -> *const Self { +/// self.b.ptr_with_address(addr) as *const Self +/// } +/// } +/// ``` +pub use domain_macros::ParseBytesByRef; + +//----------- SplitBytesByRef ------------------------------------------------ + +/// Zero-copy parsing from the start of a byte string. +/// +/// This is an extension of [`ParseBytesByRef`] for types which can determine +/// their own length when parsing. It is usually implemented by [`Sized`] +/// types (where the length is just the size of the type), although it can be +/// sometimes implemented by unsized types. +/// +/// # Safety +/// +/// Every implementation of [`SplitBytesByRef`] must satisfy the invariants +/// documented on [`split_bytes_by_ref()`]. An incorrect implementation is +/// considered to cause undefined behaviour. +/// +/// [`split_bytes_by_ref()`]: Self::split_bytes_by_ref() +/// +/// Note that [`ParseBytesByRef`], required by this trait, also has several +/// invariants that need to be considered with care. +pub unsafe trait SplitBytesByRef: ParseBytesByRef { + /// Interpret a byte string as an instance of [`Self`], mutably. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The length of [`Self`] will be determined, possibly based + /// on the contents (but not the length!) of the input, and the remaining + /// bytes will be returned. If the input does not begin with a valid + /// instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let (this, rest) = T::split_bytes_by_ref(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this) + rest.len()`. + /// - `bytes.as_ptr().offset(size_of_val(this)) == rest.as_ptr()`. + fn split_bytes_by_ref(bytes: &[u8]) + -> Result<(&Self, &[u8]), ParseError>; + + /// Interpret a byte string as an instance of [`Self`]. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The length of [`Self`] will be determined, possibly based + /// on the contents (but not the length!) of the input, and the remaining + /// bytes will be returned. If the input does not begin with a valid + /// instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let (this, rest) = T::split_bytes_by_mut(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this) + rest.len()`. + /// - `bytes.as_ptr().offset(size_of_val(this)) == rest.as_ptr()`. + fn split_bytes_by_mut( + bytes: &mut [u8], + ) -> Result<(&mut Self, &mut [u8]), ParseError>; +} + +unsafe impl SplitBytesByRef for u8 { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + bytes.split_first().ok_or(ParseError) + } + + fn split_bytes_by_mut( + bytes: &mut [u8], + ) -> Result<(&mut Self, &mut [u8]), ParseError> { + bytes.split_first_mut().ok_or(ParseError) + } +} + +unsafe impl SplitBytesByRef for [T; N] { + fn split_bytes_by_ref( + mut bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + let start = bytes.as_ptr(); + for _ in 0..N { + (_, bytes) = T::split_bytes_by_ref(bytes)?; + } + + // SAFETY: + // - 'T::split_bytes_by_ref()' was called 'N' times on successive + // positions, thus the original 'bytes' starts with 'N' instances + // of 'T' (even if 'T' is a ZST and so all instances overlap). + // - 'N' consecutive 'T's have the same layout as '[T; N]'. + // - Thus it is safe to cast 'start' to '[T; N]'. + // - The referenced data has the same lifetime as the output. + Ok((unsafe { &*start.cast::<[T; N]>() }, bytes)) + } + + fn split_bytes_by_mut( + mut bytes: &mut [u8], + ) -> Result<(&mut Self, &mut [u8]), ParseError> { + let start = bytes.as_mut_ptr(); + for _ in 0..N { + (_, bytes) = T::split_bytes_by_mut(bytes)?; + } + + // SAFETY: + // - 'T::split_bytes_by_ref()' was called 'N' times on successive + // positions, thus the original 'bytes' starts with 'N' instances + // of 'T' (even if 'T' is a ZST and so all instances overlap). + // - 'N' consecutive 'T's have the same layout as '[T; N]'. + // - Thus it is safe to cast 'start' to '[T; N]'. + // - The referenced data has the same lifetime as the output. + Ok((unsafe { &mut *start.cast::<[T; N]>() }, bytes)) + } +} + +/// Deriving [`SplitBytesByRef`] automatically. +/// +/// [`SplitBytesByRef`] can be derived on `struct`s (not `enum`s or `union`s), +/// where a fixed memory layout (`repr(C)` or `repr(transparent)`) is used. +/// All fields must implement [`SplitBytesByRef`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{ParseBytesByRef, SplitBytesByRef, U32, ParseError}; +/// #[derive(ParseBytesByRef)] +/// #[repr(C)] +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(SplitBytesByRef)': +/// unsafe impl SplitBytesByRef for Foo +/// where Bar: SplitBytesByRef { +/// fn split_bytes_by_ref( +/// bytes: &[u8], +/// ) -> Result<(&Self, &[u8]), ParseError> { +/// let addr = bytes.as_ptr(); +/// let (_, bytes) = U32::split_bytes_by_ref(bytes)?; +/// let (last, bytes) = >::split_bytes_by_ref(bytes)?; +/// let this = last.ptr_with_address(addr as *const ()); +/// Ok((unsafe { &*(this as *const Self) }, bytes)) +/// } +/// +/// fn split_bytes_by_mut( +/// bytes: &mut [u8], +/// ) -> Result<(&mut Self, &mut [u8]), ParseError> { +/// let addr = bytes.as_ptr(); +/// let (_, bytes) = U32::split_bytes_by_mut(bytes)?; +/// let (last, bytes) = >::split_bytes_by_mut(bytes)?; +/// let this = last.ptr_with_address(addr as *const ()); +/// Ok((unsafe { &mut *(this as *const Self as *mut Self) }, bytes)) +/// } +/// } +/// ``` +pub use domain_macros::SplitBytesByRef; + +//----------- ParseError ----------------------------------------------------- + +/// A DNS message parsing error. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ParseError; + +//--- Formatting + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("DNS data could not be parsed from the wire format") + } +} diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 1a815e615..36d96a4f4 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -18,7 +18,7 @@ use domain_macros::*; use crate::new_base::Serial; #[cfg(all(feature = "std", feature = "siphasher"))] -use crate::new_base::build::{AsBytes, TruncationError}; +use crate::new_base::wire::{AsBytes, TruncationError}; //----------- CookieRequest -------------------------------------------------- @@ -71,7 +71,7 @@ impl CookieRequest { use siphasher::sip::SipHasher24; - use crate::new_base::build::BuildBytes; + use crate::new_base::wire::BuildBytes; // Build and hash the cookie simultaneously. let mut hasher = SipHasher24::new_with_key(secret); diff --git a/src/new_edns/ext_err.rs b/src/new_edns/ext_err.rs index 6858aa713..7613afd8e 100644 --- a/src/new_edns/ext_err.rs +++ b/src/new_edns/ext_err.rs @@ -6,12 +6,12 @@ use core::fmt; use domain_macros::*; -use zerocopy::network_endian::U16; +use crate::new_base::wire::U16; //----------- ExtError ------------------------------------------------------- /// An extended DNS error. -#[derive(AsBytes, ParseBytesByRef)] +#[derive(AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C)] pub struct ExtError { /// The error code. diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 3b3ab2a80..781224360 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -4,16 +4,14 @@ use core::{fmt, ops::Range}; -use zerocopy::{network_endian::U16, IntoBytes}; - use domain_macros::*; use crate::{ new_base::{ - build::{AsBytes, BuildBytes, TruncationError}, - parse::{ - ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, - SplitBytes, SplitFromMessage, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, + SplitBytes, TruncationError, U16, }, Message, }, diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index d9e8829ac..14c7cdc9f 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -10,15 +10,14 @@ use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; -use zerocopy::network_endian::{U16, U32}; - use domain_macros::*; use crate::new_base::{ - build::{self, AsBytes, BuildIntoMessage, TruncationError}, - parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, + build::{self, BuildIntoMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, ParseBytes, ParseError, SplitBytes, TruncationError, U16, + U32, }, CharStr, Message, Serial, }; @@ -386,8 +385,6 @@ impl<'a> ParseFromMessage<'a> for HInfo<'a> { message: &'a Message, range: Range, ) -> Result { - use zerocopy::IntoBytes; - message .as_bytes() .get(range) @@ -501,8 +498,6 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { message: &'a Message, range: Range, ) -> Result { - use zerocopy::IntoBytes; - message .as_bytes() .get(range) diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index fb3f9d30e..788a1ca97 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -10,8 +10,9 @@ use std::net::Ipv6Addr; use domain_macros::*; -use crate::new_base::build::{ - self, AsBytes, BuildIntoMessage, TruncationError, +use crate::new_base::{ + build::{self, BuildIntoMessage}, + wire::{AsBytes, TruncationError}, }; //----------- Aaaa ----------------------------------------------------------- diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index ebdcc7743..23be1e18f 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -5,11 +5,9 @@ use core::ops::Range; use domain_macros::*; use crate::new_base::{ - build::{BuildBytes, BuildIntoMessage, Builder, TruncationError}, - parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, - }, + build::{self, BuildIntoMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, ParseRecordData, RType, }; @@ -131,7 +129,7 @@ where impl BuildIntoMessage for RecordData<'_, N> { fn build_into_message( &self, - builder: Builder<'_>, + builder: build::Builder<'_>, ) -> Result<(), TruncationError> { match self { Self::A(r) => r.build_into_message(builder), From af13cf14cb31684770205eb2f83ed4b79edbddf8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 18:34:51 +0100 Subject: [PATCH 072/111] [new_base] Correct docs for build traits --- src/new_base/build/mod.rs | 5 ++--- src/new_base/wire/build.rs | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 2faca3c16..35752d2e2 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -11,9 +11,8 @@ pub use super::wire::TruncationError; pub trait BuildIntoMessage { // Append this value to the DNS message. /// - /// If the byte string is long enough to fit the message, it is appended - /// using the given message builder and committed. Otherwise, a - /// [`TruncationError`] is returned. + /// If the builder has enough capacity to fit the message, it is appended + /// and committed. Otherwise, a [`TruncationError`] is returned. fn build_into_message( &self, builder: Builder<'_>, diff --git a/src/new_base/wire/build.rs b/src/new_base/wire/build.rs index 1b67a4d40..88b6a44b3 100644 --- a/src/new_base/wire/build.rs +++ b/src/new_base/wire/build.rs @@ -12,7 +12,7 @@ pub trait BuildBytes { /// buffer. If the buffer is large enough, the whole object is written /// and the remaining (unmodified) part of the buffer is returned. /// - /// if the buffer is too small, a [`TruncationError`] is returned (and + /// If the buffer is too small, a [`TruncationError`] is returned (and /// parts of the buffer may be modified). fn build_bytes<'b>( &self, From 8daf6b502b7c172fd9cccce55a29026f12f1e566 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 18:42:25 +0100 Subject: [PATCH 073/111] [macros] Avoid glob imports where possible --- macros/src/data.rs | 2 +- macros/src/impls.rs | 58 ++++++++++++++++++++++++++------------------- macros/src/repr.rs | 12 +++++++--- 3 files changed, 43 insertions(+), 29 deletions(-) diff --git a/macros/src/data.rs b/macros/src/data.rs index 6a0788b3e..ee0c52baf 100644 --- a/macros/src/data.rs +++ b/macros/src/data.rs @@ -4,7 +4,7 @@ use std::ops::Deref; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use syn::{spanned::Spanned, *}; +use syn::{spanned::Spanned, Field, Fields, Ident, Index, Member, Token}; //----------- Struct --------------------------------------------------------- diff --git a/macros/src/impls.rs b/macros/src/impls.rs index 2d9724f0e..4c0971998 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -2,7 +2,11 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; -use syn::{punctuated::Punctuated, visit::Visit, *}; +use syn::{ + punctuated::Punctuated, visit::Visit, ConstParam, GenericArgument, + GenericParam, Ident, Lifetime, LifetimeParam, Token, TypeParam, + TypeParamBound, WhereClause, WherePredicate, +}; //----------- ImplSkeleton --------------------------------------------------- @@ -21,24 +25,24 @@ pub struct ImplSkeleton { pub unsafety: Option, /// The trait being implemented. - pub bound: Option, + pub bound: Option, /// The type being implemented on. - pub subject: Path, + pub subject: syn::Path, /// The where clause of the `impl` block. pub where_clause: WhereClause, /// The contents of the `impl`. - pub contents: Block, + pub contents: syn::Block, /// A `const` block for asserting requirements. - pub requirements: Block, + pub requirements: syn::Block, } impl ImplSkeleton { /// Construct an [`ImplSkeleton`] for a [`DeriveInput`]. - pub fn new(input: &DeriveInput, unsafety: bool) -> Self { + pub fn new(input: &syn::DeriveInput, unsafety: bool) -> Self { let mut lifetimes = Vec::new(); let mut types = Vec::new(); let mut consts = Vec::new(); @@ -55,13 +59,13 @@ impl ImplSkeleton { GenericParam::Type(value) => { types.push(value.clone()); let id = value.ident.clone(); - let id = TypePath { + let id = syn::TypePath { qself: None, - path: Path { + path: syn::Path { leading_colon: None, - segments: [PathSegment { + segments: [syn::PathSegment { ident: id, - arguments: PathArguments::None, + arguments: syn::PathArguments::None, }] .into_iter() .collect(), @@ -73,13 +77,13 @@ impl ImplSkeleton { GenericParam::Const(value) => { consts.push(value.clone()); let id = value.ident.clone(); - let id = TypePath { + let id = syn::TypePath { qself: None, - path: Path { + path: syn::Path { leading_colon: None, - segments: [PathSegment { + segments: [syn::PathSegment { ident: id, - arguments: PathArguments::None, + arguments: syn::PathArguments::None, }] .into_iter() .collect(), @@ -92,12 +96,12 @@ impl ImplSkeleton { let unsafety = unsafety.then_some(::default()); - let subject = Path { + let subject = syn::Path { leading_colon: None, - segments: [PathSegment { + segments: [syn::PathSegment { ident: input.ident.clone(), - arguments: PathArguments::AngleBracketed( - AngleBracketedGenericArguments { + arguments: syn::PathArguments::AngleBracketed( + syn::AngleBracketedGenericArguments { colon2_token: None, lt_token: Default::default(), args: subject_args, @@ -115,12 +119,12 @@ impl ImplSkeleton { predicates: Punctuated::new(), }); - let contents = Block { + let contents = syn::Block { brace_token: Default::default(), stmts: Vec::new(), }; - let requirements = Block { + let requirements = syn::Block { brace_token: Default::default(), stmts: Vec::new(), }; @@ -142,7 +146,11 @@ impl ImplSkeleton { /// /// If the type is concrete, a verifying statement is added for it. /// Otherwise, it is added to the where clause. - pub fn require_bound(&mut self, target: Type, bound: TypeParamBound) { + pub fn require_bound( + &mut self, + target: syn::Type, + bound: TypeParamBound, + ) { let mut visitor = ConcretenessVisitor { skeleton: self, is_concrete: true, @@ -154,7 +162,7 @@ impl ImplSkeleton { if visitor.is_concrete { // Add a concrete requirement for this bound. - self.requirements.stmts.push(parse_quote! { + self.requirements.stmts.push(syn::parse_quote! { const _: fn() = || { fn assert_impl() {} assert_impl::<#target>(); @@ -164,7 +172,7 @@ impl ImplSkeleton { // Add this bound to the `where` clause. let mut bounds = Punctuated::new(); bounds.push(bound); - let pred = WherePredicate::Type(PredicateType { + let pred = WherePredicate::Type(syn::PredicateType { lifetimes: None, bounded_ty: target, colon_token: Default::default(), @@ -196,9 +204,9 @@ impl ImplSkeleton { let lifetime = self.new_lifetime(prefix); let mut bounds = bounds.into_iter().peekable(); let param = if bounds.peek().is_some() { - parse_quote! { #lifetime: #(#bounds)+* } + syn::parse_quote! { #lifetime: #(#bounds)+* } } else { - parse_quote! { #lifetime } + syn::parse_quote! { #lifetime } }; (lifetime, param) } diff --git a/macros/src/repr.rs b/macros/src/repr.rs index 80c900eb6..b699b571b 100644 --- a/macros/src/repr.rs +++ b/macros/src/repr.rs @@ -1,7 +1,10 @@ //! Determining the memory layout of a type. use proc_macro2::Span; -use syn::{punctuated::Punctuated, spanned::Spanned, *}; +use syn::{ + punctuated::Punctuated, spanned::Spanned, Attribute, Error, LitInt, Meta, + Token, +}; //----------- Repr ----------------------------------------------------------- @@ -19,7 +22,10 @@ impl Repr { /// Determine the representation for a type from its attributes. /// /// This will fail if a stable representation cannot be found. - pub fn determine(attrs: &[Attribute], bound: &str) -> Result { + pub fn determine( + attrs: &[Attribute], + bound: &str, + ) -> Result { let mut repr = None; for attr in attrs { if !attr.path().is_ident("repr") { @@ -57,7 +63,7 @@ impl Repr { || meta.path.is_ident("aligned") => { let span = meta.span(); - let lit: LitInt = parse2(meta.tokens)?; + let lit: LitInt = syn::parse2(meta.tokens)?; let n: usize = lit.base10_parse()?; if n != 1 { return Err(Error::new(span, From 8bf87bee84595a1e28ee3db1d828a51e72e192f2 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 18:46:27 +0100 Subject: [PATCH 074/111] [macros/lib.rs] Remove the last 'syn' glob import --- macros/src/lib.rs | 122 +++++++++++++++++++++++----------------------- 1 file changed, 62 insertions(+), 60 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 3fa1bc18b..a23e6902a 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -5,7 +5,7 @@ use proc_macro as pm; use proc_macro2::TokenStream; use quote::{format_ident, ToTokens}; -use syn::*; +use syn::{Error, Ident, Result}; mod impls; use impls::ImplSkeleton; @@ -20,16 +20,16 @@ use repr::Repr; #[proc_macro_derive(SplitBytes)] pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'SplitBytes' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'SplitBytes' can only be 'derive'd for 'struct's", @@ -47,7 +47,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { ); skeleton.lifetimes.push(param); skeleton.bound = Some( - parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); // Inspect the 'struct' fields. @@ -58,14 +58,14 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); } // Define 'parse_bytes()'. let init_vars = builder.init_vars(); let tys = data.fields().map(|f| &f.ty); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn split_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -82,7 +82,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -92,16 +92,16 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(ParseBytes)] pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'ParseBytes' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'ParseBytes' can only be 'derive'd for 'struct's", @@ -119,7 +119,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { ); skeleton.lifetimes.push(param); skeleton.bound = Some( - parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), ); // Inspect the 'struct' fields. @@ -130,19 +130,19 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); } if let Some(field) = data.unsized_field() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), ); } // Finish early if the 'struct' has no fields. if data.is_empty() { - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -165,7 +165,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { let tys = builder.sized_fields().map(|f| &f.ty); let unsized_ty = &builder.unsized_field().unwrap().ty; let unsized_init_var = builder.unsized_init_var().unwrap(); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -185,7 +185,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -195,16 +195,16 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(SplitBytesByRef)] pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'SplitBytesByRef' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'SplitBytesByRef' can only be 'derive'd for 'struct's", @@ -216,8 +216,9 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); - skeleton.bound = - Some(parse_quote!(::domain::new_base::wire::SplitBytesByRef)); + skeleton.bound = Some(syn::parse_quote!( + ::domain::new_base::wire::SplitBytesByRef + )); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -226,13 +227,13 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::SplitBytesByRef), + syn::parse_quote!(::domain::new_base::wire::SplitBytesByRef), ); } // Finish early if the 'struct' has no fields. if data.is_empty() { - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn split_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -255,7 +256,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'split_bytes_by_ref()'. let tys = data.sized_fields().map(|f| &f.ty); let unsized_ty = &data.unsized_field().unwrap().ty; - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn split_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -289,7 +290,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'split_bytes_by_mut()'. let tys = data.sized_fields().map(|f| &f.ty); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn split_bytes_by_mut( bytes: &mut [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -324,7 +325,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -334,16 +335,16 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(ParseBytesByRef)] pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'ParseBytesByRef' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'ParseBytesByRef' can only be 'derive'd for 'struct's", @@ -355,8 +356,9 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); - skeleton.bound = - Some(parse_quote!(::domain::new_base::wire::ParseBytesByRef)); + skeleton.bound = Some(syn::parse_quote!( + ::domain::new_base::wire::ParseBytesByRef + )); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -365,19 +367,19 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::SplitBytesByRef), + syn::parse_quote!(::domain::new_base::wire::SplitBytesByRef), ); } if let Some(field) = data.unsized_field() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::ParseBytesByRef), + syn::parse_quote!(::domain::new_base::wire::ParseBytesByRef), ); } // Finish early if the 'struct' has no fields. if data.is_empty() { - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -395,7 +397,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn ptr_with_address( &self, addr: *const (), @@ -410,7 +412,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'parse_bytes_by_ref()'. let tys = data.sized_fields().map(|f| &f.ty); let unsized_ty = &data.unsized_field().unwrap().ty; - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -444,7 +446,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'parse_bytes_by_mut()'. let tys = data.sized_fields().map(|f| &f.ty); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes_by_mut( bytes: &mut [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -478,7 +480,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'ptr_with_address()'. let unsized_member = data.unsized_member(); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn ptr_with_address(&self, addr: *const ()) -> *const Self { <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::ptr_with_address(&self.#unsized_member, addr) @@ -489,7 +491,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -499,16 +501,16 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(BuildBytes)] pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'BuildBytes' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'BuildBytes' can only be 'derive'd for 'struct's", @@ -519,7 +521,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, false); skeleton.bound = - Some(parse_quote!(::domain::new_base::wire::BuildBytes)); + Some(syn::parse_quote!(::domain::new_base::wire::BuildBytes)); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -531,14 +533,14 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::BuildBytes), + syn::parse_quote!(::domain::new_base::wire::BuildBytes), ); } // Define 'build_bytes()'. let members = data.members(); let tys = data.fields().map(|f| &f.ty); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn build_bytes<#lifetime>( &self, mut bytes: & #lifetime mut [::domain::__core::primitive::u8], @@ -555,7 +557,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -565,16 +567,16 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(AsBytes)] pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'AsBytes' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'AsBytes' can only be 'derive'd for 'struct's", @@ -587,13 +589,13 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = - Some(parse_quote!(::domain::new_base::wire::AsBytes)); + Some(syn::parse_quote!(::domain::new_base::wire::AsBytes)); // Establish bounds on the fields. for field in data.fields.iter() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::AsBytes), + syn::parse_quote!(::domain::new_base::wire::AsBytes), ); } @@ -602,7 +604,7 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -611,6 +613,6 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { //----------- Utility Functions ---------------------------------------------- /// Add a `field_` prefix to member names. -fn field_prefixed(member: Member) -> Ident { +fn field_prefixed(member: syn::Member) -> Ident { format_ident!("field_{}", member) } From 8afc305d60d0d2be93a83a6afe3ec33a85b150b9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 7 Jan 2025 12:09:17 +0100 Subject: [PATCH 075/111] [new_base/parse] Rework '{Parse,Split}FromMessage' Instead of passing the input selection as a range for parsing, the whole message is cut (using 'Message::slice_to()') and only the start is indicated. This ensures that we never cross the end of the range. It also implicitly dictates that compressed names are not allowed to reference future locations in messages. In addition, both the parsing traits now use offsets into the message _contents_ rather than the whole message. They can avoid 'as_bytes()' everywhere and have better guarantees of success. It also ensures the message header can never be selected for parsing. --- src/new_base/charstr.rs | 15 ++++++-------- src/new_base/message.rs | 17 +++++++++++++-- src/new_base/name/reversed.rs | 25 ++++++++++------------ src/new_base/parse/mod.rs | 36 ++++++++++++++++---------------- src/new_base/question.rs | 9 +++----- src/new_base/record.rs | 24 ++++++++------------- src/new_edns/mod.rs | 12 +++++------ src/new_rdata/basic.rs | 39 +++++++++++++++++------------------ src/new_rdata/mod.rs | 26 +++++++++++------------ 9 files changed, 99 insertions(+), 104 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 979e8be20..ce7e1fba7 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -1,14 +1,11 @@ //! DNS "character strings". -use core::{fmt, ops::Range}; +use core::fmt; use super::{ build::{self, BuildIntoMessage}, parse::{ParseFromMessage, SplitFromMessage}, - wire::{ - AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, - TruncationError, - }, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, }; @@ -28,7 +25,7 @@ impl<'a> SplitFromMessage<'a> for &'a CharStr { message: &'a Message, start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = &message.as_bytes()[start..]; + let bytes = message.contents.get(start..).ok_or(ParseError)?; let (this, rest) = Self::split_bytes(bytes)?; Ok((this, bytes.len() - rest.len())) } @@ -37,11 +34,11 @@ impl<'a> SplitFromMessage<'a> for &'a CharStr { impl<'a> ParseFromMessage<'a> for &'a CharStr { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { message - .as_bytes() - .get(range) + .contents + .get(start..) .ok_or(ParseError) .and_then(Self::parse_bytes) } diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 9c27d384f..ab1aa903c 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -2,9 +2,9 @@ use core::fmt; -use domain_macros::{AsBytes, *}; +use domain_macros::*; -use super::wire::U16; +use super::wire::{AsBytes, ParseBytesByRef, U16}; //----------- Message -------------------------------------------------------- @@ -19,6 +19,19 @@ pub struct Message { pub contents: [u8], } +//--- Interaction + +impl Message { + /// Truncate the contents of this message to the given size. + /// + /// The returned value will have a `contents` field of the given size. + pub fn slice_to(&self, size: usize) -> &Self { + let bytes = &self.as_bytes()[..12 + size]; + Self::parse_bytes_by_ref(bytes) + .expect("A 12-or-more byte string is a valid 'Message'") + } +} + //----------- Header --------------------------------------------------------- /// A DNS message header. diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index ba7cdb8c6..6432219a3 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -5,16 +5,13 @@ use core::{ cmp::Ordering, fmt, hash::{Hash, Hasher}, - ops::{Deref, Range}, + ops::Deref, }; use crate::new_base::{ build::{self, BuildIntoMessage}, parse::{ParseFromMessage, SplitFromMessage}, - wire::{ - AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, - TruncationError, - }, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, }; @@ -255,13 +252,13 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { // disallow a name to point to data _after_ it. Standard name // compressors will never generate such pointers. - let message = message.as_bytes(); + let contents = &message.contents; let mut buffer = Self::empty(); // Perform the first iteration early, to catch the end of the name. - let bytes = message.get(start..).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; - let orig_end = message.len() - rest.len(); + let orig_end = contents.len() - rest.len(); // Traverse compression pointers. let mut old_start = start; @@ -272,7 +269,7 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { } // Keep going, from the referenced position. - let bytes = message.get(start..).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; old_start = start; continue; @@ -288,17 +285,17 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { impl<'a> ParseFromMessage<'a> for RevNameBuf { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { // See 'split_from_message()' for details. The only differences are // in the range of the first iteration, and the check that the first // iteration exactly covers the input range. - let message = message.as_bytes(); + let contents = &message.contents; let mut buffer = Self::empty(); // Perform the first iteration early, to catch the end of the name. - let bytes = message.get(range.clone()).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; if !rest.is_empty() { @@ -307,7 +304,7 @@ impl<'a> ParseFromMessage<'a> for RevNameBuf { } // Traverse compression pointers. - let mut old_start = range.start; + let mut old_start = start; while let Some(start) = pointer.map(usize::from) { // Ensure the referenced position comes earlier. if start >= old_start { @@ -315,7 +312,7 @@ impl<'a> ParseFromMessage<'a> for RevNameBuf { } // Keep going, from the referenced position. - let bytes = message.get(start..).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; old_start = start; continue; diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index d36dd9543..7e5a08d7f 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,7 +1,5 @@ //! Parsing DNS messages from the wire format. -use core::ops::Range; - mod message; pub use message::{MessagePart, ParseMessage, VisitMessagePart}; @@ -14,7 +12,7 @@ pub use record::{ParseRecord, ParseRecords, VisitRecord}; pub use super::wire::ParseError; use super::{ - wire::{AsBytes, ParseBytesByRef, SplitBytesByRef}, + wire::{ParseBytesByRef, SplitBytesByRef}, Message, }; @@ -22,11 +20,14 @@ use super::{ /// A type that can be parsed from a DNS message. pub trait SplitFromMessage<'a>: Sized + ParseFromMessage<'a> { - /// Parse a value of [`Self`] from the start of a byte string within a - /// particular DNS message. + /// Parse a value from the start of a byte string within a DNS message. + /// + /// The byte string to parse is `message.contents[start..]`. The previous + /// data in the message can be used for resolving compressed names. /// - /// If parsing is successful, the parsed value and the rest of the string - /// are returned. Otherwise, a [`ParseError`] is returned. + /// If parsing is successful, the parsed value and the offset for the rest + /// of the input are returned. If `len` bytes were parsed to form `self`, + /// `start + len` should be the returned offset. fn split_from_message( message: &'a Message, start: usize, @@ -35,14 +36,15 @@ pub trait SplitFromMessage<'a>: Sized + ParseFromMessage<'a> { /// A type that can be parsed from a string in a DNS message. pub trait ParseFromMessage<'a>: Sized { - /// Parse a value of [`Self`] from a byte string within a particular DNS - /// message. + /// Parse a value from a byte string within a DNS message. + /// + /// The byte string to parse is `message.contents[start..]`. The previous + /// data in the message can be used for resolving compressed names. /// - /// If parsing is successful, the parsed value is returned. Otherwise, a - /// [`ParseError`] is returned. + /// If parsing is successful, the parsed value is returned. fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result; } @@ -51,20 +53,18 @@ impl<'a, T: ?Sized + SplitBytesByRef> SplitFromMessage<'a> for &'a T { message: &'a Message, start: usize, ) -> Result<(Self, usize), ParseError> { - let message = message.as_bytes(); - let bytes = message.get(start..).ok_or(ParseError)?; + let bytes = message.contents.get(start..).ok_or(ParseError)?; let (this, rest) = T::split_bytes_by_ref(bytes)?; - Ok((this, message.len() - rest.len())) + Ok((this, bytes.len() - rest.len())) } } impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let message = message.as_bytes(); - let bytes = message.get(range).ok_or(ParseError)?; + let bytes = message.contents.get(start..).ok_or(ParseError)?; T::parse_bytes_by_ref(bytes) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 0dad0910a..b961f0af7 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -1,7 +1,5 @@ //! DNS questions. -use core::ops::Range; - use domain_macros::*; use super::{ @@ -66,12 +64,11 @@ where { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let (qname, rest) = N::split_from_message(message, range.start)?; + let (qname, rest) = N::split_from_message(message, start)?; let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; - let &qclass = - <&QClass>::parse_from_message(message, rest..range.end)?; + let &qclass = <&QClass>::parse_from_message(message, rest)?; Ok(Self::new(qname, qtype, qclass)) } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 2d84e0934..5686f09f3 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -1,9 +1,6 @@ //! DNS records. -use core::{ - borrow::Borrow, - ops::{Deref, Range}, -}; +use core::{borrow::Borrow, ops::Deref}; use super::{ build::{self, BuildIntoMessage}, @@ -78,8 +75,9 @@ where let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; let (&size, rest) = <&U16>::split_from_message(message, rest)?; let size: usize = size.get().into(); - let rdata = if message.as_bytes().len() - rest >= size { - D::parse_record_data(message, rest..rest + size, rtype)? + let rdata = if message.contents.len() - rest >= size { + let message = message.slice_to(rest + size); + D::parse_record_data(message, rest, rtype)? } else { return Err(ParseError); }; @@ -95,15 +93,11 @@ where { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let message = &message.as_bytes()[..range.end]; - let message = Message::parse_bytes_by_ref(message) - .expect("The input range ends past the message header"); + let (this, rest) = Self::split_from_message(message, start)?; - let (this, rest) = Self::split_from_message(message, range.start)?; - - if rest == range.end { + if rest == message.contents.len() { Ok(this) } else { Err(ParseError) @@ -334,10 +328,10 @@ pub trait ParseRecordData<'a>: Sized { /// Parse DNS record data of the given type from a DNS message. fn parse_record_data( message: &'a Message, - range: Range, + start: usize, rtype: RType, ) -> Result { - let bytes = message.as_bytes().get(range).ok_or(ParseError)?; + let bytes = message.contents.get(start..).ok_or(ParseError)?; Self::parse_record_data_bytes(bytes, rtype) } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 781224360..9afd69167 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -2,7 +2,7 @@ //! //! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). -use core::{fmt, ops::Range}; +use core::fmt; use domain_macros::*; @@ -54,20 +54,20 @@ impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { message: &'a Message, start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.as_bytes().get(start..).ok_or(ParseError)?; + let bytes = message.contents.get(start..).ok_or(ParseError)?; let (this, rest) = Self::split_bytes(bytes)?; - Ok((this, message.as_bytes().len() - rest.len())) + Ok((this, message.contents.len() - rest.len())) } } impl<'a> ParseFromMessage<'a> for EdnsRecord<'a> { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { message - .as_bytes() - .get(range) + .contents + .get(start..) .ok_or(ParseError) .and_then(Self::parse_bytes) } diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 14c7cdc9f..0f295ec5d 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -2,7 +2,7 @@ //! //! See [RFC 1035](https://datatracker.ietf.org/doc/html/rfc1035). -use core::{fmt, ops::Range}; +use core::fmt; #[cfg(feature = "std")] use core::str::FromStr; @@ -123,9 +123,9 @@ pub struct Ns { impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - N::parse_from_message(message, range).map(|name| Self { name }) + N::parse_from_message(message, start).map(|name| Self { name }) } } @@ -167,9 +167,9 @@ pub struct CName { impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for CName { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - N::parse_from_message(message, range).map(|name| Self { name }) + N::parse_from_message(message, start).map(|name| Self { name }) } } @@ -226,15 +226,15 @@ pub struct Soa { impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let (mname, rest) = N::split_from_message(message, range.start)?; + let (mname, rest) = N::split_from_message(message, start)?; let (rname, rest) = N::split_from_message(message, rest)?; let (&serial, rest) = <&Serial>::split_from_message(message, rest)?; let (&refresh, rest) = <&U32>::split_from_message(message, rest)?; let (&retry, rest) = <&U32>::split_from_message(message, rest)?; let (&expire, rest) = <&U32>::split_from_message(message, rest)?; - let &minimum = <&U32>::parse_from_message(message, rest..range.end)?; + let &minimum = <&U32>::parse_from_message(message, rest)?; Ok(Self { mname, @@ -349,9 +349,9 @@ pub struct Ptr { impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - N::parse_from_message(message, range).map(|name| Self { name }) + N::parse_from_message(message, start).map(|name| Self { name }) } } @@ -383,11 +383,11 @@ pub struct HInfo<'a> { impl<'a> ParseFromMessage<'a> for HInfo<'a> { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { message - .as_bytes() - .get(range) + .contents + .get(start..) .ok_or(ParseError) .and_then(Self::parse_bytes) } @@ -437,11 +437,10 @@ pub struct Mx { impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let (&preference, rest) = - <&U16>::split_from_message(message, range.start)?; - let exchange = N::parse_from_message(message, rest..range.end)?; + let (&preference, rest) = <&U16>::split_from_message(message, start)?; + let exchange = N::parse_from_message(message, rest)?; Ok(Self { preference, exchange, @@ -496,11 +495,11 @@ impl Txt { impl<'a> ParseFromMessage<'a> for &'a Txt { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { message - .as_bytes() - .get(range) + .contents + .get(start..) .ok_or(ParseError) .and_then(Self::parse_bytes) } diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 23be1e18f..7617e4ad5 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,7 +1,5 @@ //! Record data types. -use core::ops::Range; - use domain_macros::*; use crate::new_base::{ @@ -70,35 +68,35 @@ where { fn parse_record_data( message: &'a Message, - range: Range, + start: usize, rtype: RType, ) -> Result { match rtype { - RType::A => <&A>::parse_from_message(message, range).map(Self::A), - RType::NS => Ns::parse_from_message(message, range).map(Self::Ns), + RType::A => <&A>::parse_from_message(message, start).map(Self::A), + RType::NS => Ns::parse_from_message(message, start).map(Self::Ns), RType::CNAME => { - CName::parse_from_message(message, range).map(Self::CName) + CName::parse_from_message(message, start).map(Self::CName) } RType::SOA => { - Soa::parse_from_message(message, range).map(Self::Soa) + Soa::parse_from_message(message, start).map(Self::Soa) } RType::WKS => { - <&Wks>::parse_from_message(message, range).map(Self::Wks) + <&Wks>::parse_from_message(message, start).map(Self::Wks) } RType::PTR => { - Ptr::parse_from_message(message, range).map(Self::Ptr) + Ptr::parse_from_message(message, start).map(Self::Ptr) } RType::HINFO => { - HInfo::parse_from_message(message, range).map(Self::HInfo) + HInfo::parse_from_message(message, start).map(Self::HInfo) } - RType::MX => Mx::parse_from_message(message, range).map(Self::Mx), + RType::MX => Mx::parse_from_message(message, start).map(Self::Mx), RType::TXT => { - <&Txt>::parse_from_message(message, range).map(Self::Txt) + <&Txt>::parse_from_message(message, start).map(Self::Txt) } RType::AAAA => { - <&Aaaa>::parse_from_message(message, range).map(Self::Aaaa) + <&Aaaa>::parse_from_message(message, start).map(Self::Aaaa) } - _ => <&UnknownRecordData>::parse_from_message(message, range) + _ => <&UnknownRecordData>::parse_from_message(message, start) .map(|data| Self::Unknown(rtype, data)), } } From 1037ba05dd889e7afb3390ef8b0e1f76dacccf42 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 7 Jan 2025 12:12:45 +0100 Subject: [PATCH 076/111] [new_base/name] Delete unused 'parsed.rs' --- src/new_base/name/parsed.rs | 131 ------------------------------------ 1 file changed, 131 deletions(-) delete mode 100644 src/new_base/name/parsed.rs diff --git a/src/new_base/name/parsed.rs b/src/new_base/name/parsed.rs deleted file mode 100644 index abf592e5d..000000000 --- a/src/new_base/name/parsed.rs +++ /dev/null @@ -1,131 +0,0 @@ -//! Domain names encoded in DNS messages. - -use zerocopy_derive::*; - -use crate::new_base::parse::{ParseError, ParseFrom, SplitFrom}; - -//----------- ParsedName ----------------------------------------------------- - -/// A domain name in a DNS message. -#[derive(Debug, IntoBytes, Immutable, Unaligned)] -#[repr(transparent)] -pub struct ParsedName([u8]); - -//--- Constants - -impl ParsedName { - /// The maximum size of a parsed domain name in the wire format. - /// - /// This can occur if a compression pointer is used to point to a root - /// name, even though such a representation is longer than copying the - /// root label into the name. - pub const MAX_SIZE: usize = 256; - - /// The root name. - pub const ROOT: &'static Self = { - // SAFETY: A root label is the shortest valid name. - unsafe { Self::from_bytes_unchecked(&[0u8]) } - }; -} - -//--- Construction - -impl ParsedName { - /// Assume a byte string is a valid [`ParsedName`]. - /// - /// # Safety - /// - /// The byte string must be correctly encoded in the wire format, and - /// within the size restriction (256 bytes or fewer). It must end with a - /// root label or a compression pointer. - pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { - // SAFETY: 'ParsedName' is 'repr(transparent)' to '[u8]', so casting a - // '[u8]' into a 'ParsedName' is sound. - core::mem::transmute(bytes) - } -} - -//--- Inspection - -impl ParsedName { - /// The size of this name in the wire format. - #[allow(clippy::len_without_is_empty)] - pub const fn len(&self) -> usize { - self.0.len() - } - - /// Whether this is the root label. - pub const fn is_root(&self) -> bool { - self.0.len() == 1 - } - - /// Whether this is a compression pointer. - pub const fn is_pointer(&self) -> bool { - self.0.len() == 2 - } - - /// The wire format representation of the name. - pub const fn as_bytes(&self) -> &[u8] { - &self.0 - } -} - -//--- Parsing - -impl<'a> SplitFrom<'a> for &'a ParsedName { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - // Iterate through the labels in the name. - let mut index = 0usize; - loop { - if index >= ParsedName::MAX_SIZE || index >= bytes.len() { - return Err(ParseError); - } - let length = bytes[index]; - if length == 0 { - // This was the root label. - index += 1; - break; - } else if length < 0x40 { - // This was the length of the label. - index += 1 + length as usize; - } else if length >= 0xC0 { - // This was a compression pointer. - if index + 1 >= bytes.len() { - return Err(ParseError); - } - index += 2; - break; - } else { - // This was a reserved or deprecated label type. - return Err(ParseError); - } - } - - let (name, bytes) = bytes.split_at(index); - // SAFETY: 'bytes' has been confirmed to be correctly encoded. - Ok((unsafe { ParsedName::from_bytes_unchecked(name) }, bytes)) - } -} - -impl<'a> ParseFrom<'a> for &'a ParsedName { - fn parse_from(bytes: &'a [u8]) -> Result { - Self::split_from(bytes).and_then(|(name, rest)| { - rest.is_empty().then_some(name).ok_or(ParseError) - }) - } -} - -//--- Conversion to and from bytes - -impl AsRef<[u8]> for ParsedName { - /// The bytes in the name in the wire format. - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -impl<'a> From<&'a ParsedName> for &'a [u8] { - fn from(name: &'a ParsedName) -> Self { - name.as_bytes() - } -} From 5faf44d2739dae80f9b2bacc3536348ddb6757d9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 7 Jan 2025 14:53:11 +0100 Subject: [PATCH 077/111] [new_base/wire] Define 'SizePrefixed' --- src/new_base/name/reversed.rs | 1 + src/new_base/record.rs | 59 +++---- src/new_base/wire/mod.rs | 5 +- src/new_base/wire/size_prefixed.rs | 248 +++++++++++++++++++++++++++++ src/new_edns/mod.rs | 96 ++++++----- 5 files changed, 317 insertions(+), 92 deletions(-) create mode 100644 src/new_base/wire/size_prefixed.rs diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 6432219a3..e851911f3 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -269,6 +269,7 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { } // Keep going, from the referenced position. + let start = start.checked_sub(12).ok_or(ParseError)?; let bytes = contents.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; old_start = start; diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 5686f09f3..f65d1c5d5 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -8,7 +8,7 @@ use super::{ parse::{ParseFromMessage, SplitFromMessage}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, - SplitBytes, SplitBytesByRef, TruncationError, U16, U32, + SizePrefixed, SplitBytes, SplitBytesByRef, TruncationError, U16, U32, }, Message, }; @@ -73,16 +73,13 @@ where let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; - let (&size, rest) = <&U16>::split_from_message(message, rest)?; - let size: usize = size.get().into(); - let rdata = if message.contents.len() - rest >= size { - let message = message.slice_to(rest + size); - D::parse_record_data(message, rest, rtype)? - } else { - return Err(ParseError); - }; + let rdata_start = rest; + let (_, rest) = + <&SizePrefixed<[u8]>>::split_from_message(message, rest)?; + let message = message.slice_to(rest); + let rdata = D::parse_record_data(message, rdata_start, rtype)?; - Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest + size)) + Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } } @@ -95,13 +92,14 @@ where message: &'a Message, start: usize, ) -> Result { - let (this, rest) = Self::split_from_message(message, start)?; + let (rname, rest) = N::split_from_message(message, start)?; + let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; + let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; + let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; + let _ = <&SizePrefixed<[u8]>>::parse_from_message(message, rest)?; + let rdata = D::parse_record_data(message, rest, rtype)?; - if rest == message.contents.len() { - Ok(this) - } else { - Err(ParseError) - } + Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } } @@ -148,13 +146,7 @@ where let (rtype, rest) = RType::split_bytes(rest)?; let (rclass, rest) = RClass::split_bytes(rest)?; let (ttl, rest) = TTL::split_bytes(rest)?; - let (size, rest) = U16::split_bytes(rest)?; - let size: usize = size.get().into(); - if rest.len() < size { - return Err(ParseError); - } - - let (rdata, rest) = rest.split_at(size); + let (rdata, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) @@ -171,13 +163,8 @@ where let (rtype, rest) = RType::split_bytes(rest)?; let (rclass, rest) = RClass::split_bytes(rest)?; let (ttl, rest) = TTL::split_bytes(rest)?; - let (size, rest) = U16::split_bytes(rest)?; - let size: usize = size.get().into(); - if rest.len() != size { - return Err(ParseError); - } - - let rdata = D::parse_record_data_bytes(rest, rtype)?; + let rdata = <&SizePrefixed<[u8]>>::parse_bytes(rest)?; + let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -198,17 +185,9 @@ where bytes = self.rtype.as_bytes().build_bytes(bytes)?; bytes = self.rclass.as_bytes().build_bytes(bytes)?; bytes = self.ttl.as_bytes().build_bytes(bytes)?; + bytes = SizePrefixed::new(&self.rdata).build_bytes(bytes)?; - let (size, bytes) = - U16::split_bytes_by_mut(bytes).map_err(|_| TruncationError)?; - let bytes_len = bytes.len(); - - let rest = self.rdata.build_bytes(bytes)?; - *size = u16::try_from(bytes_len - rest.len()) - .expect("the record data never exceeds 64KiB") - .into(); - - Ok(rest) + Ok(bytes) } } diff --git a/src/new_base/wire/mod.rs b/src/new_base/wire/mod.rs index 4d5be5c25..41f131af7 100644 --- a/src/new_base/wire/mod.rs +++ b/src/new_base/wire/mod.rs @@ -1,4 +1,4 @@ -//! The basic wire format of network protocols. +//! Low-level byte serialization. //! //! This is a low-level module providing simple and efficient mechanisms to //! parse data from and build data into byte sequences. It takes inspiration @@ -79,3 +79,6 @@ pub use parse::{ mod ints; pub use ints::{U16, U32, U64}; + +mod size_prefixed; +pub use size_prefixed::SizePrefixed; diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs new file mode 100644 index 000000000..5e4fc217e --- /dev/null +++ b/src/new_base/wire/size_prefixed.rs @@ -0,0 +1,248 @@ +//! Working with (U16-)size-prefixed data. + +use core::{ + borrow::{Borrow, BorrowMut}, + ops::{Deref, DerefMut}, +}; + +use super::{ + AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SplitBytes, + SplitBytesByRef, TruncationError, U16, +}; + +//----------- SizePrefixed --------------------------------------------------- + +/// A wrapper adding a 16-bit size prefix to a message. +/// +/// This is a common element in DNS messages (e.g. for record data and EDNS +/// options). When serialized as bytes, the inner value is prefixed with a +/// 16-bit network-endian integer indicating the length of the inner value in +/// bytes. +#[derive(Copy, Clone)] +#[repr(C)] +pub struct SizePrefixed { + /// The size prefix (needed for 'ParseBytesByRef' / 'AsBytes'). + /// + /// This value is always consistent with the size of 'data' if it is + /// (de)serialized in-place. By the bounds on 'ParseBytesByRef' and + /// 'AsBytes', the serialized size is the same as 'size_of_val(&data)'. + size: U16, + + /// The inner data. + data: T, +} + +//--- Construction + +impl SizePrefixed { + const VALID_SIZE: () = assert!(core::mem::size_of::() < 65536); + + /// Construct a [`SizePrefixed`]. + /// + /// # Panics + /// + /// Panics if the data is 64KiB or more in size. + pub const fn new(data: T) -> Self { + // Force the 'VALID_SIZE' assertion to be evaluated. + #[allow(clippy::let_unit_value)] + let _ = Self::VALID_SIZE; + + Self { + size: U16::new(core::mem::size_of::() as u16), + data, + } + } +} + +//--- Conversion from the inner data + +impl From for SizePrefixed { + fn from(value: T) -> Self { + Self::new(value) + } +} + +//--- Access to the inner data + +impl Deref for SizePrefixed { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.data + } +} + +impl DerefMut for SizePrefixed { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.data + } +} + +impl Borrow for SizePrefixed { + fn borrow(&self) -> &T { + &self.data + } +} + +impl BorrowMut for SizePrefixed { + fn borrow_mut(&mut self) -> &mut T { + &mut self.data + } +} + +impl AsRef for SizePrefixed { + fn as_ref(&self) -> &T { + &self.data + } +} + +impl AsMut for SizePrefixed { + fn as_mut(&mut self) -> &mut T { + &mut self.data + } +} + +//--- Parsing from bytes + +impl<'b, T: ParseBytes<'b>> ParseBytes<'b> for SizePrefixed { + fn parse_bytes(bytes: &'b [u8]) -> Result { + let (size, rest) = U16::split_bytes(bytes)?; + if rest.len() != size.get() as usize { + return Err(ParseError); + } + let data = T::parse_bytes(bytes)?; + Ok(Self { size, data }) + } +} + +impl<'b, T: ParseBytes<'b>> SplitBytes<'b> for SizePrefixed { + fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { + let (size, rest) = U16::split_bytes(bytes)?; + if rest.len() < size.get() as usize { + return Err(ParseError); + } + let (data, rest) = rest.split_at(size.get() as usize); + let data = T::parse_bytes(data)?; + Ok((Self { size, data }, rest)) + } +} + +unsafe impl ParseBytesByRef for SizePrefixed { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + let addr = bytes.as_ptr(); + let (size, rest) = U16::split_bytes_by_ref(bytes)?; + if rest.len() != size.get() as usize { + return Err(ParseError); + } + let last = T::parse_bytes_by_ref(rest)?; + let ptr = last.ptr_with_address(addr as *const ()); + + // SAFETY: + // - 'bytes' is a 'U16' followed by a 'T'. + // - 'T' is 'ParseBytesByRef' and so is unaligned. + // - 'Self' is 'repr(C)' and so has no alignment or padding. + // - The layout of 'Self' is identical to '(U16, T)'. + Ok(unsafe { &*(ptr as *const Self) }) + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + let addr = bytes.as_ptr(); + let (size, rest) = U16::split_bytes_by_mut(bytes)?; + if rest.len() != size.get() as usize { + return Err(ParseError); + } + let last = T::parse_bytes_by_mut(rest)?; + let ptr = last.ptr_with_address(addr as *const ()); + + // SAFETY: + // - 'bytes' is a 'U16' followed by a 'T'. + // - 'T' is 'ParseBytesByRef' and so is unaligned. + // - 'Self' is 'repr(C)' and so has no alignment or padding. + // - The layout of 'Self' is identical to '(U16, T)'. + Ok(unsafe { &mut *(ptr as *const Self as *mut Self) }) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + self.data.ptr_with_address(addr) as *const Self + } +} + +unsafe impl SplitBytesByRef for SizePrefixed { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + let addr = bytes.as_ptr(); + let (size, rest) = U16::split_bytes_by_ref(bytes)?; + if rest.len() < size.get() as usize { + return Err(ParseError); + } + let (data, rest) = rest.split_at(size.get() as usize); + let last = T::parse_bytes_by_ref(data)?; + let ptr = last.ptr_with_address(addr as *const ()); + + // SAFETY: + // - 'bytes' is a 'U16' followed by a 'T'. + // - 'T' is 'ParseBytesByRef' and so is unaligned. + // - 'Self' is 'repr(C)' and so has no alignment or padding. + // - The layout of 'Self' is identical to '(U16, T)'. + Ok((unsafe { &*(ptr as *const Self) }, rest)) + } + + fn split_bytes_by_mut( + bytes: &mut [u8], + ) -> Result<(&mut Self, &mut [u8]), ParseError> { + let addr = bytes.as_ptr(); + let (size, rest) = U16::split_bytes_by_mut(bytes)?; + if rest.len() < size.get() as usize { + return Err(ParseError); + } + let (data, rest) = rest.split_at_mut(size.get() as usize); + let last = T::parse_bytes_by_mut(data)?; + let ptr = last.ptr_with_address(addr as *const ()); + + // SAFETY: + // - 'bytes' is a 'U16' followed by a 'T'. + // - 'T' is 'ParseBytesByRef' and so is unaligned. + // - 'Self' is 'repr(C)' and so has no alignment or padding. + // - The layout of 'Self' is identical to '(U16, T)'. + Ok((unsafe { &mut *(ptr as *const Self as *mut Self) }, rest)) + } +} + +//--- Building into byte strings + +impl BuildBytes for SizePrefixed { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + // Get the size area to fill in afterwards. + let (size_buf, data_buf) = + U16::split_bytes_by_mut(bytes).map_err(|_| TruncationError)?; + let data_buf_len = data_buf.len(); + let rest = self.data.build_bytes(data_buf)?; + let size = data_buf_len - rest.len(); + assert!(size < 65536, "Cannot serialize >=64KiB into 16-bit integer"); + *size_buf = U16::new(size as u16); + Ok(rest) + } +} + +unsafe impl AsBytes for SizePrefixed { + // For debugging, we check that the serialized size is correct. + #[cfg(debug_assertions)] + fn as_bytes(&self) -> &[u8] { + let size: usize = self.size.get().into(); + assert_eq!(size, core::mem::size_of_val(&self.data)); + + // SAFETY: + // - 'Self' has no padding bytes and no interior mutability. + // - Its size in memory is exactly 'size_of_val(self)'. + unsafe { + core::slice::from_raw_parts( + self as *const Self as *const u8, + core::mem::size_of_val(self), + ) + } + } +} diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 9afd69167..152cd5dae 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -11,7 +11,7 @@ use crate::{ parse::{ParseFromMessage, SplitFromMessage}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, - SplitBytes, TruncationError, U16, + SizePrefixed, SplitBytes, TruncationError, U16, }, Message, }, @@ -44,7 +44,7 @@ pub struct EdnsRecord<'a> { pub flags: EdnsFlags, /// Extended DNS options. - pub options: &'a Opt, + pub options: SizePrefixed<&'a Opt>, } //--- Parsing from DNS messages @@ -84,15 +84,7 @@ impl<'a> SplitBytes<'a> for EdnsRecord<'a> { let (&ext_rcode, rest) = <&u8>::split_bytes(rest)?; let (&version, rest) = <&u8>::split_bytes(rest)?; let (&flags, rest) = <&EdnsFlags>::split_bytes(rest)?; - - // Split the record size and data. - let (&size, rest) = <&U16>::split_bytes(rest)?; - let size: usize = size.get().into(); - if rest.len() < size { - return Err(ParseError); - } - let (options, rest) = rest.split_at(size); - let options = Opt::parse_bytes_by_ref(options)?; + let (options, rest) = >::split_bytes(rest)?; Ok(( Self { @@ -116,14 +108,7 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { let (&ext_rcode, rest) = <&u8>::split_bytes(rest)?; let (&version, rest) = <&u8>::split_bytes(rest)?; let (&flags, rest) = <&EdnsFlags>::split_bytes(rest)?; - - // Split the record size and data. - let (&size, rest) = <&U16>::split_bytes(rest)?; - let size: usize = size.get().into(); - if rest.len() != size { - return Err(ParseError); - } - let options = Opt::parse_bytes_by_ref(rest)?; + let options = >::parse_bytes(rest)?; Ok(Self { max_udp_payload, @@ -135,6 +120,26 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { } } +//--- Building into bytes + +impl BuildBytes for EdnsRecord<'_> { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + // Add the record name (root) and the record type. + bytes = [0, 0, 41].as_slice().build_bytes(bytes)?; + + bytes = self.max_udp_payload.build_bytes(bytes)?; + bytes = self.ext_rcode.build_bytes(bytes)?; + bytes = self.version.build_bytes(bytes)?; + bytes = self.flags.build_bytes(bytes)?; + bytes = self.options.build_bytes(bytes)?; + + Ok(bytes) + } +} + //----------- EdnsFlags ------------------------------------------------------ /// Extended DNS flags describing a message. @@ -239,25 +244,22 @@ impl EdnsOption<'_> { impl<'b> ParseBytes<'b> for EdnsOption<'b> { fn parse_bytes(bytes: &'b [u8]) -> Result { let (code, rest) = OptionCode::split_bytes(bytes)?; - let (size, rest) = U16::split_bytes(rest)?; - if rest.len() != size.get() as usize { - return Err(ParseError); - } + let data = <&SizePrefixed<[u8]>>::parse_bytes(rest)?; match code { - OptionCode::COOKIE => match size.get() { - 8 => CookieRequest::parse_bytes_by_ref(rest) + OptionCode::COOKIE => match data.len() { + 8 => CookieRequest::parse_bytes_by_ref(data) .map(Self::CookieRequest), - 16..=40 => Cookie::parse_bytes_by_ref(rest).map(Self::Cookie), + 16..=40 => Cookie::parse_bytes_by_ref(data).map(Self::Cookie), _ => Err(ParseError), }, OptionCode::EXT_ERROR => { - ExtError::parse_bytes_by_ref(rest).map(Self::ExtError) + ExtError::parse_bytes_by_ref(data).map(Self::ExtError) } _ => { - let data = UnknownOption::parse_bytes_by_ref(rest)?; + let data = UnknownOption::parse_bytes_by_ref(data)?; Ok(Self::Unknown(code, data)) } } @@ -267,32 +269,25 @@ impl<'b> ParseBytes<'b> for EdnsOption<'b> { impl<'b> SplitBytes<'b> for EdnsOption<'b> { fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { let (code, rest) = OptionCode::split_bytes(bytes)?; - let (size, rest) = U16::split_bytes(rest)?; - if rest.len() < size.get() as usize { - return Err(ParseError); - } - let (bytes, rest) = rest.split_at(size.get() as usize); - - match code { - OptionCode::COOKIE => match size.get() { - 8 => CookieRequest::parse_bytes_by_ref(bytes) - .map(Self::CookieRequest), - 16..=40 => { - Cookie::parse_bytes_by_ref(bytes).map(Self::Cookie) - } - _ => Err(ParseError), + let (data, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; + + let this = match code { + OptionCode::COOKIE => match data.len() { + 8 => <&CookieRequest>::parse_bytes(data) + .map(Self::CookieRequest)?, + 16..=40 => <&Cookie>::parse_bytes(data).map(Self::Cookie)?, + _ => return Err(ParseError), }, OptionCode::EXT_ERROR => { - ExtError::parse_bytes_by_ref(bytes).map(Self::ExtError) + <&ExtError>::parse_bytes(data).map(Self::ExtError)? } - _ => { - let data = UnknownOption::parse_bytes_by_ref(bytes)?; - Ok(Self::Unknown(code, data)) - } - } - .map(|this| (this, rest)) + _ => <&UnknownOption>::parse_bytes(data) + .map(|data| Self::Unknown(code, data))?, + }; + + Ok((this, rest)) } } @@ -311,9 +306,8 @@ impl BuildBytes for EdnsOption<'_> { Self::ExtError(this) => this.as_bytes(), Self::Unknown(_, this) => this.as_bytes(), }; + bytes = SizePrefixed::new(data).build_bytes(bytes)?; - bytes = U16::new(data.len() as u16).build_bytes(bytes)?; - bytes = data.build_bytes(bytes)?; Ok(bytes) } } From a98d246e5bab2fd192565ed09814491ff2a69338 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 7 Jan 2025 15:33:18 +0100 Subject: [PATCH 078/111] [new_rdata/edns] Implement iteration and formatting --- src/new_base/record.rs | 29 ++++++++----- src/new_rdata/edns.rs | 93 +++++++++++++++++++++++++++++++++++++++++- src/new_rdata/mod.rs | 15 +++++-- 3 files changed, 122 insertions(+), 15 deletions(-) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index f65d1c5d5..8fa3f30f4 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -219,35 +219,44 @@ pub struct RType { //--- Associated Constants impl RType { + const fn new(value: u16) -> Self { + Self { + code: U16::new(value), + } + } + /// The type of an [`A`](crate::new_rdata::A) record. - pub const A: Self = Self { code: U16::new(1) }; + pub const A: Self = Self::new(1); /// The type of an [`Ns`](crate::new_rdata::Ns) record. - pub const NS: Self = Self { code: U16::new(2) }; + pub const NS: Self = Self::new(2); /// The type of a [`CName`](crate::new_rdata::CName) record. - pub const CNAME: Self = Self { code: U16::new(5) }; + pub const CNAME: Self = Self::new(5); /// The type of an [`Soa`](crate::new_rdata::Soa) record. - pub const SOA: Self = Self { code: U16::new(6) }; + pub const SOA: Self = Self::new(6); /// The type of a [`Wks`](crate::new_rdata::Wks) record. - pub const WKS: Self = Self { code: U16::new(11) }; + pub const WKS: Self = Self::new(11); /// The type of a [`Ptr`](crate::new_rdata::Ptr) record. - pub const PTR: Self = Self { code: U16::new(12) }; + pub const PTR: Self = Self::new(12); /// The type of a [`HInfo`](crate::new_rdata::HInfo) record. - pub const HINFO: Self = Self { code: U16::new(13) }; + pub const HINFO: Self = Self::new(13); /// The type of a [`Mx`](crate::new_rdata::Mx) record. - pub const MX: Self = Self { code: U16::new(15) }; + pub const MX: Self = Self::new(15); /// The type of a [`Txt`](crate::new_rdata::Txt) record. - pub const TXT: Self = Self { code: U16::new(16) }; + pub const TXT: Self = Self::new(16); /// The type of an [`Aaaa`](crate::new_rdata::Aaaa) record. - pub const AAAA: Self = Self { code: U16::new(28) }; + pub const AAAA: Self = Self::new(28); + + /// The type of an [`Opt`](crate::new_rdata::Opt) record. + pub const OPT: Self = Self::new(41); } //----------- RClass --------------------------------------------------------- diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index c53a715a7..5c2ccce29 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -2,9 +2,17 @@ //! //! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). +use core::{fmt, iter::FusedIterator}; + use domain_macros::*; -use crate::new_base::build::{self, BuildIntoMessage, TruncationError}; +use crate::{ + new_base::{ + build::{self, BuildIntoMessage, TruncationError}, + wire::{ParseError, SplitBytes}, + }, + new_edns::EdnsOption, +}; //----------- Opt ------------------------------------------------------------ @@ -18,7 +26,23 @@ pub struct Opt { contents: [u8], } -// TODO: Parsing the EDNS options. +//--- Inspection + +impl Opt { + /// Traverse the options in this record. + pub fn options(&self) -> EdnsOptionsIter<'_> { + EdnsOptionsIter::new(&self.contents) + } +} + +//--- Formatting + +impl fmt::Debug for Opt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Opt").field(&self.options()).finish() + } +} + // TODO: Formatting. //--- Building into DNS messages @@ -31,3 +55,68 @@ impl BuildIntoMessage for Opt { self.contents.build_into_message(builder) } } + +//----------- EdnsOptionsIter ------------------------------------------------ + +/// An iterator over EDNS options in an [`Opt`] record. +#[derive(Clone)] +pub struct EdnsOptionsIter<'a> { + /// The serialized options to parse from. + options: &'a [u8], +} + +//--- Construction + +impl<'a> EdnsOptionsIter<'a> { + /// Construct a new [`EdnsOptionsIter`]. + pub const fn new(options: &'a [u8]) -> Self { + Self { options } + } +} + +//--- Inspection + +impl<'a> EdnsOptionsIter<'a> { + /// The serialized options yet to be parsed. + pub const fn remaining(&self) -> &'a [u8] { + self.options + } +} + +//--- Formatting + +impl fmt::Debug for EdnsOptionsIter<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut entries = f.debug_set(); + for option in self.clone() { + match option { + Ok(option) => entries.entry(&option), + Err(_err) => entries.entry(&format_args!("")), + }; + } + entries.finish() + } +} + +//--- Iteration + +impl<'a> Iterator for EdnsOptionsIter<'a> { + type Item = Result, ParseError>; + + fn next(&mut self) -> Option { + if !self.options.is_empty() { + let options = core::mem::take(&mut self.options); + match EdnsOption::split_bytes(options) { + Ok((option, rest)) => { + self.options = rest; + Some(Ok(option)) + } + Err(err) => Some(Err(err)), + } + } else { + None + } + } +} + +impl FusedIterator for EdnsOptionsIter<'_> {} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 7617e4ad5..ac1804ea8 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -18,7 +18,7 @@ mod ipv6; pub use ipv6::Aaaa; mod edns; -pub use edns::Opt; +pub use edns::{EdnsOptionsIter, Opt}; //----------- RecordData ----------------------------------------------------- @@ -56,6 +56,9 @@ pub enum RecordData<'a, N> { /// The IPv6 address of a host responsible for this domain. Aaaa(&'a Aaaa), + /// Extended DNS options. + Opt(&'a Opt), + /// Data for an unknown DNS record type. Unknown(RType, &'a UnknownRecordData), } @@ -96,6 +99,9 @@ where RType::AAAA => { <&Aaaa>::parse_from_message(message, start).map(Self::Aaaa) } + RType::OPT => { + <&Opt>::parse_from_message(message, start).map(Self::Opt) + } _ => <&UnknownRecordData>::parse_from_message(message, start) .map(|data| Self::Unknown(rtype, data)), } @@ -116,6 +122,7 @@ where RType::MX => Mx::parse_bytes(bytes).map(Self::Mx), RType::TXT => <&Txt>::parse_bytes(bytes).map(Self::Txt), RType::AAAA => <&Aaaa>::parse_bytes(bytes).map(Self::Aaaa), + RType::OPT => <&Opt>::parse_bytes(bytes).map(Self::Opt), _ => <&UnknownRecordData>::parse_bytes(bytes) .map(|data| Self::Unknown(rtype, data)), } @@ -137,9 +144,10 @@ impl BuildIntoMessage for RecordData<'_, N> { Self::Wks(r) => r.build_into_message(builder), Self::Ptr(r) => r.build_into_message(builder), Self::HInfo(r) => r.build_into_message(builder), + Self::Mx(r) => r.build_into_message(builder), Self::Txt(r) => r.build_into_message(builder), Self::Aaaa(r) => r.build_into_message(builder), - Self::Mx(r) => r.build_into_message(builder), + Self::Opt(r) => r.build_into_message(builder), Self::Unknown(_, r) => r.octets.build_into_message(builder), } } @@ -158,9 +166,10 @@ impl BuildBytes for RecordData<'_, N> { Self::Wks(r) => r.build_bytes(bytes), Self::Ptr(r) => r.build_bytes(bytes), Self::HInfo(r) => r.build_bytes(bytes), + Self::Mx(r) => r.build_bytes(bytes), Self::Txt(r) => r.build_bytes(bytes), Self::Aaaa(r) => r.build_bytes(bytes), - Self::Mx(r) => r.build_bytes(bytes), + Self::Opt(r) => r.build_bytes(bytes), Self::Unknown(_, r) => r.build_bytes(bytes), } } From 3ea33ac19fa5f097381dc16b42ddecb2c928f262 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 9 Jan 2025 08:12:37 +0100 Subject: [PATCH 079/111] [new_base/build] Use 'BuildResult' to ensure 'commit()' is called This has already caught a missing commit (for 'Question'). --- src/new_base/build/builder.rs | 9 +++++- src/new_base/build/mod.rs | 34 +++++++++++++---------- src/new_base/charstr.rs | 7 ++--- src/new_base/name/reversed.rs | 12 +++----- src/new_base/question.rs | 8 +++--- src/new_base/record.rs | 12 +++----- src/new_rdata/basic.rs | 52 ++++++++++------------------------- src/new_rdata/edns.rs | 7 ++--- src/new_rdata/ipv6.rs | 9 ++---- src/new_rdata/mod.rs | 7 ++--- 10 files changed, 63 insertions(+), 94 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index e02da91db..8a0bbba33 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -12,6 +12,8 @@ use crate::new_base::{ Header, Message, }; +use super::BuildCommitted; + //----------- Builder -------------------------------------------------------- /// A DNS message builder. @@ -214,8 +216,13 @@ impl Builder<'_> { } /// Commit all appended content. - pub fn commit(&mut self) { + /// + /// For convenience, a unit type [`BuildCommitted`] is returned; it is + /// used as the return type of build functions to remind users to call + /// this method on success paths. + pub fn commit(&mut self) -> BuildCommitted { self.commit = self.context.size; + BuildCommitted } /// Mark bytes in the buffer as initialized. diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 35752d2e2..86e21ddfc 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -3,7 +3,7 @@ mod builder; pub use builder::{Builder, BuilderContext}; -pub use super::wire::TruncationError; +use super::wire::TruncationError; //----------- Message-aware building traits ---------------------------------- @@ -13,28 +13,32 @@ pub trait BuildIntoMessage { /// /// If the builder has enough capacity to fit the message, it is appended /// and committed. Otherwise, a [`TruncationError`] is returned. - fn build_into_message( - &self, - builder: Builder<'_>, - ) -> Result<(), TruncationError>; + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult; } impl BuildIntoMessage for &T { - fn build_into_message( - &self, - builder: Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult { (**self).build_into_message(builder) } } impl BuildIntoMessage for [u8] { - fn build_into_message( - &self, - mut builder: Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { builder.append_bytes(self)?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } + +//----------- BuildResult ---------------------------------------------------- + +/// The result of building into a DNS message. +pub type BuildResult = Result; + +//----------- BuildCommitted ------------------------------------------------- + +/// The output of [`Builder::commit()`]. +/// +/// This is a stub type to remind users to call [`Builder::commit()`] in all +/// success paths of building functions. +#[derive(Debug)] +pub struct BuildCommitted; diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index ce7e1fba7..8df3c3d7c 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -3,7 +3,7 @@ use core::fmt; use super::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, parse::{ParseFromMessage, SplitFromMessage}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, @@ -50,11 +50,10 @@ impl BuildIntoMessage for CharStr { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { builder.append_bytes(&[self.octets.len() as u8])?; builder.append_bytes(&self.octets)?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index e851911f3..aa58b24cf 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -9,7 +9,7 @@ use core::{ }; use crate::new_base::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, parse::{ParseFromMessage, SplitFromMessage}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, @@ -99,10 +99,9 @@ impl BuildIntoMessage for RevName { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { builder.append_name(self)?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -371,10 +370,7 @@ fn parse_segment<'a>( //--- Building into DNS messages impl BuildIntoMessage for RevNameBuf { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { (**self).build_into_message(builder) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index b961f0af7..720d46e14 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -3,10 +3,10 @@ use domain_macros::*; use super::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, name::RevNameBuf, parse::{ParseFromMessage, SplitFromMessage}, - wire::{AsBytes, ParseError, TruncationError, U16}, + wire::{AsBytes, ParseError, U16}, Message, }; @@ -82,11 +82,11 @@ where fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { self.qname.build_into_message(builder.delegate())?; builder.append_bytes(self.qtype.as_bytes())?; builder.append_bytes(self.qclass.as_bytes())?; - Ok(()) + Ok(builder.commit()) } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 8fa3f30f4..5e380ae0f 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -3,7 +3,7 @@ use core::{borrow::Borrow, ops::Deref}; use super::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, name::RevNameBuf, parse::{ParseFromMessage, SplitFromMessage}, wire::{ @@ -113,7 +113,7 @@ where fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { self.rname.build_into_message(builder.delegate())?; builder.append_bytes(self.rtype.as_bytes())?; builder.append_bytes(self.rclass.as_bytes())?; @@ -129,8 +129,7 @@ where builder.appended_mut()[offset..offset + 2] .copy_from_slice(&size.to_be_bytes()); - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -372,10 +371,7 @@ impl<'a> ParseRecordData<'a> for &'a UnparsedRecordData { //--- Building into DNS messages impl BuildIntoMessage for UnparsedRecordData { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.0.build_into_message(builder) } } diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 0f295ec5d..456da881c 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -13,12 +13,9 @@ use std::net::Ipv4Addr; use domain_macros::*; use crate::new_base::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, parse::{ParseFromMessage, SplitFromMessage}, - wire::{ - AsBytes, ParseBytes, ParseError, SplitBytes, TruncationError, U16, - U32, - }, + wire::{AsBytes, ParseBytes, ParseError, SplitBytes, U16, U32}, CharStr, Message, Serial, }; @@ -88,10 +85,7 @@ impl fmt::Display for A { //--- Building into DNS messages impl BuildIntoMessage for A { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.as_bytes().build_into_message(builder) } } @@ -132,10 +126,7 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { //--- Building into DNS messages impl BuildIntoMessage for Ns { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.name.build_into_message(builder) } } @@ -176,10 +167,7 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for CName { //--- Building into DNS messages impl BuildIntoMessage for CName { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.name.build_into_message(builder) } } @@ -254,7 +242,7 @@ impl BuildIntoMessage for Soa { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { self.mname.build_into_message(builder.delegate())?; self.rname.build_into_message(builder.delegate())?; builder.append_bytes(self.serial.as_bytes())?; @@ -262,8 +250,7 @@ impl BuildIntoMessage for Soa { builder.append_bytes(self.retry.as_bytes())?; builder.append_bytes(self.expire.as_bytes())?; builder.append_bytes(self.minimum.as_bytes())?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -314,10 +301,7 @@ impl fmt::Debug for Wks { //--- Building into DNS messages impl BuildIntoMessage for Wks { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.as_bytes().build_into_message(builder) } } @@ -358,10 +342,7 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { //--- Building into DNS messages impl BuildIntoMessage for Ptr { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.name.build_into_message(builder) } } @@ -399,11 +380,10 @@ impl BuildIntoMessage for HInfo<'_> { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { self.cpu.build_into_message(builder.delegate())?; self.os.build_into_message(builder.delegate())?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -454,11 +434,10 @@ impl BuildIntoMessage for Mx { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { builder.append_bytes(self.preference.as_bytes())?; self.exchange.build_into_message(builder.delegate())?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -508,10 +487,7 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { //--- Building into DNS messages impl BuildIntoMessage for Txt { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.content.build_into_message(builder) } } diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 5c2ccce29..43327b50c 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -8,7 +8,7 @@ use domain_macros::*; use crate::{ new_base::{ - build::{self, BuildIntoMessage, TruncationError}, + build::{self, BuildIntoMessage, BuildResult}, wire::{ParseError, SplitBytes}, }, new_edns::EdnsOption, @@ -48,10 +48,7 @@ impl fmt::Debug for Opt { //--- Building into DNS messages impl BuildIntoMessage for Opt { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.contents.build_into_message(builder) } } diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index 788a1ca97..f91ae5e7b 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -11,8 +11,8 @@ use std::net::Ipv6Addr; use domain_macros::*; use crate::new_base::{ - build::{self, BuildIntoMessage}, - wire::{AsBytes, TruncationError}, + build::{self, BuildIntoMessage, BuildResult}, + wire::AsBytes, }; //----------- Aaaa ----------------------------------------------------------- @@ -81,10 +81,7 @@ impl fmt::Display for Aaaa { //--- Building into DNS messages impl BuildIntoMessage for Aaaa { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.as_bytes().build_into_message(builder) } } diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index ac1804ea8..e4b94a538 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -3,7 +3,7 @@ use domain_macros::*; use crate::new_base::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, parse::{ParseFromMessage, SplitFromMessage}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, ParseRecordData, RType, @@ -132,10 +132,7 @@ where //--- Building record data impl BuildIntoMessage for RecordData<'_, N> { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { match self { Self::A(r) => r.build_into_message(builder), Self::Ns(r) => r.build_into_message(builder), From 6caee28ebc672d9e9e26705b2f98fc28af39b0f5 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 9 Jan 2025 13:03:49 +0100 Subject: [PATCH 080/111] [new_base/build] Define 'MessageBuilder' --- src/new_base/build/builder.rs | 16 +- src/new_base/build/message.rs | 243 +++++++++++++++++++++++++++++ src/new_base/build/mod.rs | 3 + src/new_base/record.rs | 13 +- src/new_base/wire/size_prefixed.rs | 61 +++++++- 5 files changed, 321 insertions(+), 15 deletions(-) create mode 100644 src/new_base/build/message.rs diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 8a0bbba33..274bbd7c6 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -133,7 +133,12 @@ impl<'b> Builder<'b> { } /// The appended but uncommitted contents of the message, mutably. - pub fn appended_mut(&mut self) -> &mut [u8] { + /// + /// # Safety + /// + /// The caller must not modify any compressed names among these bytes. + /// This can invalidate name compression state. + pub unsafe fn appended_mut(&mut self) -> &mut [u8] { // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. let range = self.commit..self.context.size; unsafe { &mut (*self.message.as_ptr()).contents[range] } @@ -175,6 +180,15 @@ impl<'b> Builder<'b> { .expect("'message' represents a valid 'Message'") } + /// A pointer to the message, including any uncommitted contents. + /// + /// The first `commit` bytes of the message contents (also provided by + /// [`Self::committed()`]) are immutably borrowed for the lifetime `'b`. + /// The remainder of the message is initialized and borrowed by `self`. + pub fn cur_message_ptr(&self) -> NonNull { + self.cur_message().into() + } + /// The builder context. pub fn context(&self) -> &BuilderContext { &*self.context diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs new file mode 100644 index 000000000..fd8fa3b45 --- /dev/null +++ b/src/new_base/build/message.rs @@ -0,0 +1,243 @@ +//! Building whole DNS messages. + +//----------- MessageBuilder ------------------------------------------------- + +use crate::new_base::{ + wire::TruncationError, Header, Message, Question, Record, +}; + +use super::{BuildIntoMessage, Builder, BuilderContext}; + +/// A builder for a whole DNS message. +/// +/// This is subtly different from a regular [`Builder`] -- it does not allow +/// for commits and so can always modify the entire message. It has methods +/// for adding entire questions and records to the message. +pub struct MessageBuilder<'b> { + /// The underlying [`Builder`]. + /// + /// Its commit point is always 0. + inner: Builder<'b>, +} + +//--- Initialization + +impl<'b> MessageBuilder<'b> { + /// Construct a [`MessageBuilder`] from raw parts. + /// + /// # Safety + /// + /// - `message` and `context` are paired together. + pub unsafe fn from_raw_parts( + message: &'b mut Message, + context: &'b mut BuilderContext, + ) -> Self { + // SAFETY: since 'commit' is 0, no part of the message is immutably + // borrowed; it is thus sound to represent as a mutable borrow. + let inner = + unsafe { Builder::from_raw_parts(message.into(), context, 0) }; + Self { inner } + } + + /// Initialize an empty [`MessageBuilder`]. + /// + /// The message header is left uninitialized. use [`Self::header_mut()`] + /// to initialize it. + /// + /// # Panics + /// + /// Panics if the buffer is less than 12 bytes long (which is the minimum + /// possible size for a DNS message). + pub fn new( + buffer: &'b mut [u8], + context: &'b mut BuilderContext, + ) -> Self { + let inner = Builder::new(buffer, context); + Self { inner } + } +} + +//--- Inspection + +impl<'b> MessageBuilder<'b> { + /// The message header. + /// + /// The header can be modified by the builder, and so is only available + /// for a short lifetime. Note that it implements [`Copy`]. + pub fn header(&self) -> &Header { + self.inner.header() + } + + /// Mutable access to the message header. + pub fn header_mut(&mut self) -> &mut Header { + self.inner.header_mut() + } + + /// Uninitialized space in the message buffer. + /// + /// This can be filled manually, then marked as initialized using + /// [`Self::mark_appended()`]. + pub fn uninitialized(&mut self) -> &mut [u8] { + self.inner.uninitialized() + } + + /// The message built thus far. + pub fn message(&self) -> &Message { + self.inner.cur_message() + } + + /// The message built thus far, mutably. + /// + /// # Safety + /// + /// The caller must not modify any compressed names among these bytes. + /// This can invalidate name compression state. + pub unsafe fn message_mut(&mut self) -> &mut Message { + // SAFETY: Since no bytes are committed, and the rest of the message + // is borrowed mutably for 'self', we can use a mutable reference. + unsafe { self.inner.cur_message_ptr().as_mut() } + } + + /// The builder context. + pub fn context(&self) -> &BuilderContext { + self.inner.context() + } + + /// Decompose this builder into raw parts. + /// + /// This returns the message buffer and the context for this builder. The + /// two are linked, and the builder can be recomposed with + /// [`Self::raw_from_parts()`]. + pub fn into_raw_parts(self) -> (&'b mut Message, &'b mut BuilderContext) { + let (mut message, context, _commit) = self.inner.into_raw_parts(); + // SAFETY: As per 'Builder::into_raw_parts()', the message is borrowed + // mutably for the lifetime 'b. Since the commit point is 0, there is + // no immutably-borrowed content in the message, so it can be turned + // into a regular reference. + (unsafe { message.as_mut() }, context) + } +} + +//--- Interaction + +impl MessageBuilder<'_> { + /// Mark bytes in the buffer as initialized. + /// + /// The given number of bytes from the beginning of + /// [`Self::uninitialized()`] will be marked as initialized, and will be + /// treated as appended content in the buffer. + /// + /// # Panics + /// + /// Panics if the uninitialized buffer is smaller than the given number of + /// initialized bytes. + pub fn mark_appended(&mut self, amount: usize) { + self.inner.mark_appended(amount) + } + + /// Limit the total message size. + /// + /// The message will not be allowed to exceed the given size, in bytes. + /// Only the message header and contents are counted; the enclosing UDP + /// or TCP packet size is not considered. If the message already exceeds + /// this size, a [`TruncationError`] is returned. + /// + /// This size will apply to all builders for this message (including those + /// that delegated to `self`). It will not be automatically revoked if + /// message building fails. + /// + /// # Panics + /// + /// Panics if the given size is less than 12 bytes. + pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { + self.inner.limit_to(size) + } + + /// Append a question. + /// + /// # Panics + /// + /// Panics if the message contains any records (as questions must come + /// before all records). + pub fn append_question( + &mut self, + question: &Question, + ) -> Result<(), TruncationError> + where + N: BuildIntoMessage, + { + // Ensure there are no records present. + let header = self.header(); + let records = header.counts.answers + + header.counts.authorities + + header.counts.additional; + assert_eq!(records, 0); + + question.build_into_message(self.inner.delegate())?; + + self.header_mut().counts.questions += 1; + Ok(()) + } + + /// Append an answer record. + /// + /// # Panics + /// + /// Panics if the message contains any authority or additional records. + pub fn append_answer( + &mut self, + record: &Record, + ) -> Result<(), TruncationError> + where + N: BuildIntoMessage, + D: BuildIntoMessage, + { + // Ensure there are no authority or additional records present. + let header = self.header(); + let records = header.counts.authorities + header.counts.additional; + assert_eq!(records, 0); + + record.build_into_message(self.inner.delegate())?; + + self.header_mut().counts.answers += 1; + Ok(()) + } + + /// Append an authority record. + /// + /// # Panics + /// + /// Panics if the message contains any additional records. + pub fn append_authority( + &mut self, + record: &Record, + ) -> Result<(), TruncationError> + where + N: BuildIntoMessage, + D: BuildIntoMessage, + { + // Ensure there are no additional records present. + let header = self.header(); + let records = header.counts.additional; + assert_eq!(records, 0); + + record.build_into_message(self.inner.delegate())?; + + self.header_mut().counts.authorities += 1; + Ok(()) + } + + /// Append an additional record. + pub fn append_additional( + &mut self, + record: &Record, + ) -> Result<(), TruncationError> + where + N: BuildIntoMessage, + D: BuildIntoMessage, + { + record.build_into_message(self.inner.delegate())?; + self.header_mut().counts.additional += 1; + Ok(()) + } +} diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 86e21ddfc..80b2b0942 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -3,6 +3,9 @@ mod builder; pub use builder::{Builder, BuilderContext}; +mod message; +pub use message::MessageBuilder; + use super::wire::TruncationError; //----------- Message-aware building traits ---------------------------------- diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 5e380ae0f..742a66977 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -118,17 +118,8 @@ where builder.append_bytes(self.rtype.as_bytes())?; builder.append_bytes(self.rclass.as_bytes())?; builder.append_bytes(self.ttl.as_bytes())?; - - // The offset of the record data size. - let offset = builder.appended().len(); - builder.append_bytes(&0u16.to_be_bytes())?; - self.rdata.build_into_message(builder.delegate())?; - let size = builder.appended().len() - 2 - offset; - let size = - u16::try_from(size).expect("the record data never exceeds 64KiB"); - builder.appended_mut()[offset..offset + 2] - .copy_from_slice(&size.to_be_bytes()); - + SizePrefixed::new(&self.rdata) + .build_into_message(builder.delegate())?; Ok(builder.commit()) } } diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index 5e4fc217e..751a57395 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -5,6 +5,12 @@ use core::{ ops::{Deref, DerefMut}, }; +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::{ParseFromMessage, SplitFromMessage}, + Message, +}; + use super::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SplitBytes, SplitBytesByRef, TruncationError, U16, @@ -102,6 +108,37 @@ impl AsMut for SizePrefixed { } } +//--- Parsing from DNS messages + +impl<'b, T: ParseFromMessage<'b>> ParseFromMessage<'b> for SizePrefixed { + fn parse_from_message( + message: &'b Message, + start: usize, + ) -> Result { + let (&size, rest) = <&U16>::split_from_message(message, start)?; + if rest + size.get() as usize != message.contents.len() { + return Err(ParseError); + } + T::parse_from_message(message, rest).map(Self::new) + } +} + +impl<'b, T: ParseFromMessage<'b>> SplitFromMessage<'b> for SizePrefixed { + fn split_from_message( + message: &'b Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let (&size, rest) = <&U16>::split_from_message(message, start)?; + let (start, rest) = (rest, rest + size.get() as usize); + if rest > message.contents.len() { + return Err(ParseError); + } + let message = message.slice_to(rest); + let data = T::parse_from_message(message, start)?; + Ok((Self::new(data), rest)) + } +} + //--- Parsing from bytes impl<'b, T: ParseBytes<'b>> ParseBytes<'b> for SizePrefixed { @@ -110,8 +147,7 @@ impl<'b, T: ParseBytes<'b>> ParseBytes<'b> for SizePrefixed { if rest.len() != size.get() as usize { return Err(ParseError); } - let data = T::parse_bytes(bytes)?; - Ok(Self { size, data }) + T::parse_bytes(bytes).map(Self::new) } } @@ -123,7 +159,7 @@ impl<'b, T: ParseBytes<'b>> SplitBytes<'b> for SizePrefixed { } let (data, rest) = rest.split_at(size.get() as usize); let data = T::parse_bytes(data)?; - Ok((Self { size, data }, rest)) + Ok((Self::new(data), rest)) } } @@ -209,6 +245,25 @@ unsafe impl SplitBytesByRef for SizePrefixed { } } +//--- Building into DNS messages + +impl BuildIntoMessage for SizePrefixed { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> BuildResult { + assert_eq!(builder.appended(), &[] as &[u8]); + builder.append_bytes(&0u16.to_be_bytes())?; + self.data.build_into_message(builder.delegate())?; + let size = builder.appended().len() - 2; + let size = u16::try_from(size).expect("the data never exceeds 64KiB"); + // SAFETY: A 'U16' is being modified, not a domain name. + let size_buf = unsafe { &mut builder.appended_mut()[0..2] }; + size_buf.copy_from_slice(&size.to_be_bytes()); + Ok(builder.commit()) + } +} + //--- Building into byte strings impl BuildBytes for SizePrefixed { From 95d0fe898b8600b0ec10abd7ab24a77aa71be134 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 9 Jan 2025 17:54:13 +0100 Subject: [PATCH 081/111] [new_base/message] Add 'as_bytes_mut()' and 'slice_to_mut()' --- src/new_base/message.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/new_base/message.rs b/src/new_base/message.rs index ab1aa903c..734a11051 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -19,6 +19,27 @@ pub struct Message { pub contents: [u8], } +//--- Inspection + +impl Message { + /// Represent this as a mutable byte sequence. + /// + /// Given `&mut self`, it is already possible to individually modify the + /// message header and contents; since neither has invalid instances, it + /// is safe to represent the entire object as mutable bytes. + pub fn as_bytes_mut(&mut self) -> &mut [u8] { + // SAFETY: + // - 'Self' has no padding bytes and no interior mutability. + // - Its size in memory is exactly 'size_of_val(self)'. + unsafe { + core::slice::from_raw_parts_mut( + self as *mut Self as *mut u8, + core::mem::size_of_val(self), + ) + } + } +} + //--- Interaction impl Message { @@ -30,6 +51,15 @@ impl Message { Self::parse_bytes_by_ref(bytes) .expect("A 12-or-more byte string is a valid 'Message'") } + + /// Truncate the contents of this message to the given size, mutably. + /// + /// The returned value will have a `contents` field of the given size. + pub fn slice_to_mut(&mut self, size: usize) -> &mut Self { + let bytes = &mut self.as_bytes_mut()[..12 + size]; + Self::parse_bytes_by_mut(bytes) + .expect("A 12-or-more byte string is a valid 'Message'") + } } //----------- Header --------------------------------------------------------- From 3f14ccae8b875947e0d1ad6f5e7aa6243162bce4 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 12:07:12 +0100 Subject: [PATCH 082/111] [new_base/name] Add 'UnparsedName' --- src/new_base/name/mod.rs | 3 + src/new_base/name/reversed.rs | 60 +++++++------ src/new_base/name/unparsed.rs | 162 ++++++++++++++++++++++++++++++++++ 3 files changed, 196 insertions(+), 29 deletions(-) create mode 100644 src/new_base/name/unparsed.rs diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index 9270f4d5c..ca9f3e581 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -19,3 +19,6 @@ pub use label::{Label, LabelIter}; mod reversed; pub use reversed::{RevName, RevNameBuf}; + +mod unparsed; +pub use unparsed::UnparsedName; diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index aa58b24cf..d022352be 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -1,4 +1,4 @@ -//! Reversed DNS names. +//! Reversed domain names. use core::{ borrow::Borrow, @@ -312,6 +312,7 @@ impl<'a> ParseFromMessage<'a> for RevNameBuf { } // Keep going, from the referenced position. + let start = start.checked_sub(12).ok_or(ParseError)?; let bytes = contents.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; old_start = start; @@ -331,38 +332,39 @@ fn parse_segment<'a>( buffer: &mut RevNameBuf, ) -> Result<(Option, &'a [u8]), ParseError> { loop { - let (&length, rest) = bytes.split_first().ok_or(ParseError)?; - if length == 0 { - // Found the root, stop. - buffer.prepend(&[0u8]); - return Ok((None, rest)); - } else if length < 64 { - // This looks like a regular label. - - if rest.len() < length as usize { - // The input doesn't contain the whole label. - return Err(ParseError); - } else if buffer.offset < 2 + length { - // The output name would exceed 254 bytes (this isn't - // the root label, so it can't fill the 255th byte). - return Err(ParseError); + match bytes { + &[0, ref rest @ ..] => { + // Found the root, stop. + buffer.prepend(&[0u8]); + return Ok((None, rest)); } - let (label, rest) = bytes.split_at(1 + length as usize); - buffer.prepend(label); - bytes = rest; - } else if length >= 0xC0 { - // This looks like a compression pointer. + &[l, ..] if l < 64 => { + // This looks like a regular label. + + if bytes.len() < 1 + l as usize { + // The input doesn't contain the whole label. + return Err(ParseError); + } else if buffer.offset < 2 + l { + // The output name would exceed 254 bytes (this isn't + // the root label, so it can't fill the 255th byte). + return Err(ParseError); + } + + let (label, rest) = bytes.split_at(1 + l as usize); + buffer.prepend(label); + bytes = rest; + } - let (&extra, rest) = rest.split_first().ok_or(ParseError)?; - let pointer = u16::from_be_bytes([length, extra]); + &[hi, lo, ref rest @ ..] if hi >= 0xC0 => { + let pointer = u16::from_be_bytes([hi, lo]); - // NOTE: We don't verify the pointer here, that's left to - // the caller (since they have to actually use it). - return Ok((Some(pointer & 0x3FFF), rest)); - } else { - // This is an invalid or deprecated label type. - return Err(ParseError); + // NOTE: We don't verify the pointer here, that's left to + // the caller (since they have to actually use it). + return Ok((Some(pointer & 0x3FFF), rest)); + } + + _ => return Err(ParseError), } } } diff --git a/src/new_base/name/unparsed.rs b/src/new_base/name/unparsed.rs new file mode 100644 index 000000000..e437ddd83 --- /dev/null +++ b/src/new_base/name/unparsed.rs @@ -0,0 +1,162 @@ +//! Unparsed domain names. + +use domain_macros::*; + +use crate::new_base::{ + parse::{ParseFromMessage, SplitFromMessage}, + wire::ParseError, + Message, +}; + +//----------- UnparsedName --------------------------------------------------- + +/// An unparsed domain name in a DNS message. +/// +/// Within a DNS message, domain names are stored in conventional order (from +/// innermost to the root label), and may end with a compression pointer. An +/// [`UnparsedName`] represents this incomplete domain name, exactly as stored +/// in a message. +#[derive(AsBytes)] +#[repr(transparent)] +pub struct UnparsedName([u8]); + +//--- Constants + +impl UnparsedName { + /// The maximum size of an unparsed domain name. + /// + /// A domain name can be 255 bytes at most, but an unparsed domain name + /// could replace the last byte (representing the root label) with a + /// compression pointer to it. Since compression pointers are 2 bytes, + /// the total size becomes 256 bytes. + pub const MAX_SIZE: usize = 256; + + /// The root name. + pub const ROOT: &'static Self = { + // SAFETY: A root label is the shortest valid name. + unsafe { Self::from_bytes_unchecked(&[0u8]) } + }; +} + +//--- Construction + +impl UnparsedName { + /// Assume a byte string is a valid [`UnparsedName`]. + /// + /// # Safety + /// + /// The byte string must contain any number of encoded labels, ending with + /// a root label or a compression pointer, as long as the size of the + /// whole string is 256 bytes or less. + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'UnparsedName' is 'repr(transparent)' to '[u8]', so casting + // a '[u8]' into an 'UnparsedName' is sound. + core::mem::transmute(bytes) + } +} + +//--- Inspection + +impl UnparsedName { + /// The size of this name in the wire format. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.len() == 1 + } + + /// A byte representation of the [`UnparsedName`]. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for &'a UnparsedName { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let bytes = message.contents.get(start..).ok_or(ParseError)?; + let mut offset = 0; + let offset = loop { + match &bytes[offset..] { + // This is the root label. + &[0, ..] => break offset + 1, + + // This looks like a regular label. + &[l, ref rest @ ..] if (1..64).contains(&l) => { + let length = l as usize; + + if rest.len() < length || offset + 2 + length > 255 { + // The name is incomplete or too big. + return Err(ParseError); + } + + offset += 1 + length; + } + + // This is a compression pointer. + &[hi, lo, ..] if hi >= 0xC0 => { + let ptr = u16::from_be_bytes([hi, lo]); + if usize::from(ptr - 0xC000) >= start { + return Err(ParseError); + } + break offset + 2; + } + + _ => return Err(ParseError), + } + }; + + let bytes = &bytes[..offset]; + let rest = start + offset; + Ok((unsafe { UnparsedName::from_bytes_unchecked(bytes) }, rest)) + } +} + +impl<'a> ParseFromMessage<'a> for &'a UnparsedName { + fn parse_from_message( + message: &'a Message, + start: usize, + ) -> Result { + let bytes = message.contents.get(start..).ok_or(ParseError)?; + let mut offset = 0; + loop { + match &bytes[offset..] { + // This is the root label. + &[0] => break, + + // This looks like a regular label. + &[l, ref rest @ ..] if (1..64).contains(&l) => { + let length = l as usize; + + if rest.len() < length || offset + 2 + length > 255 { + // The name is incomplete or too big. + return Err(ParseError); + } + + offset += 1 + length; + } + + // This is a compression pointer. + &[hi, lo] if hi >= 0xC0 => { + let ptr = u16::from_be_bytes([hi, lo]); + if usize::from(ptr - 0xC000) >= start { + return Err(ParseError); + } + break; + } + + _ => return Err(ParseError), + } + } + + Ok(unsafe { UnparsedName::from_bytes_unchecked(bytes) }) + } +} From 2c9594efde7310026af45905a0adc2663e5102b5 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 14:21:01 +0100 Subject: [PATCH 083/111] [new_base/message] Impl 'as_array()' for 'SectionCounts' --- src/new_base/message.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 734a11051..d900b4a1c 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -295,6 +295,22 @@ pub struct SectionCounts { pub additional: U16, } +//--- Interaction + +impl SectionCounts { + /// Represent these counts as an array. + pub fn as_array(&self) -> &[U16; 4] { + // SAFETY: 'SectionCounts' has the same layout as '[U16; 4]'. + unsafe { core::mem::transmute(self) } + } + + /// Represent these counts as a mutable array. + pub fn as_array_mut(&mut self) -> &mut [U16; 4] { + // SAFETY: 'SectionCounts' has the same layout as '[U16; 4]'. + unsafe { core::mem::transmute(self) } + } +} + //--- Formatting impl fmt::Display for SectionCounts { From a01a7f24092cb8fc625a39b73e2813115434c88a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 14:21:16 +0100 Subject: [PATCH 084/111] [new_base/build] Define 'RecordBuilder' --- src/new_base/build/message.rs | 114 ++++++++++++++--------- src/new_base/build/mod.rs | 3 + src/new_base/build/record.rs | 166 ++++++++++++++++++++++++++++++++++ 3 files changed, 242 insertions(+), 41 deletions(-) create mode 100644 src/new_base/build/record.rs diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index fd8fa3b45..a79928726 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -1,12 +1,13 @@ //! Building whole DNS messages. -//----------- MessageBuilder ------------------------------------------------- - use crate::new_base::{ - wire::TruncationError, Header, Message, Question, Record, + wire::TruncationError, Header, Message, Question, RClass, RType, Record, + TTL, }; -use super::{BuildIntoMessage, Builder, BuilderContext}; +use super::{BuildIntoMessage, Builder, BuilderContext, RecordBuilder}; + +//----------- MessageBuilder ------------------------------------------------- /// A builder for a whole DNS message. /// @@ -73,14 +74,6 @@ impl<'b> MessageBuilder<'b> { self.inner.header_mut() } - /// Uninitialized space in the message buffer. - /// - /// This can be filled manually, then marked as initialized using - /// [`Self::mark_appended()`]. - pub fn uninitialized(&mut self) -> &mut [u8] { - self.inner.uninitialized() - } - /// The message built thus far. pub fn message(&self) -> &Message { self.inner.cur_message() @@ -107,7 +100,7 @@ impl<'b> MessageBuilder<'b> { /// /// This returns the message buffer and the context for this builder. The /// two are linked, and the builder can be recomposed with - /// [`Self::raw_from_parts()`]. + /// [`Self::from_raw_parts()`]. pub fn into_raw_parts(self) -> (&'b mut Message, &'b mut BuilderContext) { let (mut message, context, _commit) = self.inner.into_raw_parts(); // SAFETY: As per 'Builder::into_raw_parts()', the message is borrowed @@ -121,20 +114,6 @@ impl<'b> MessageBuilder<'b> { //--- Interaction impl MessageBuilder<'_> { - /// Mark bytes in the buffer as initialized. - /// - /// The given number of bytes from the beginning of - /// [`Self::uninitialized()`] will be marked as initialized, and will be - /// treated as appended content in the buffer. - /// - /// # Panics - /// - /// Panics if the uninitialized buffer is smaller than the given number of - /// initialized bytes. - pub fn mark_appended(&mut self, amount: usize) { - self.inner.mark_appended(amount) - } - /// Limit the total message size. /// /// The message will not be allowed to exceed the given size, in bytes. @@ -167,18 +146,36 @@ impl MessageBuilder<'_> { N: BuildIntoMessage, { // Ensure there are no records present. - let header = self.header(); - let records = header.counts.answers - + header.counts.authorities - + header.counts.additional; - assert_eq!(records, 0); + assert_eq!(self.header().counts.as_array()[1..], [0, 0, 0]); question.build_into_message(self.inner.delegate())?; - self.header_mut().counts.questions += 1; Ok(()) } + /// Build an arbitrary record. + /// + /// The record will be added to the specified section (1, 2, or 3, i.e. + /// answers, authorities, and additional records respectively). There + /// must not be any existing records in sections after this one. + pub fn build_record( + &mut self, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + section: u8, + ) -> Result, TruncationError> { + RecordBuilder::new( + self.inner.delegate(), + rname, + rtype, + rclass, + ttl, + section, + ) + } + /// Append an answer record. /// /// # Panics @@ -193,16 +190,28 @@ impl MessageBuilder<'_> { D: BuildIntoMessage, { // Ensure there are no authority or additional records present. - let header = self.header(); - let records = header.counts.authorities + header.counts.additional; - assert_eq!(records, 0); + assert_eq!(self.header().counts.as_array()[2..], [0, 0]); record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.answers += 1; Ok(()) } + /// Build an answer record. + /// + /// # Panics + /// + /// Panics if the message contains any authority or additional records. + pub fn build_answer( + &mut self, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + ) -> Result, TruncationError> { + self.build_record(rname, rtype, rclass, ttl, 1) + } + /// Append an authority record. /// /// # Panics @@ -217,16 +226,28 @@ impl MessageBuilder<'_> { D: BuildIntoMessage, { // Ensure there are no additional records present. - let header = self.header(); - let records = header.counts.additional; - assert_eq!(records, 0); + assert_eq!(self.header().counts.as_array()[3..], [0]); record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.authorities += 1; Ok(()) } + /// Build an authority record. + /// + /// # Panics + /// + /// Panics if the message contains any additional records. + pub fn build_authority( + &mut self, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + ) -> Result, TruncationError> { + self.build_record(rname, rtype, rclass, ttl, 2) + } + /// Append an additional record. pub fn append_additional( &mut self, @@ -240,4 +261,15 @@ impl MessageBuilder<'_> { self.header_mut().counts.additional += 1; Ok(()) } + + /// Build an additional record. + pub fn build_additional( + &mut self, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + ) -> Result, TruncationError> { + self.build_record(rname, rtype, rclass, ttl, 3) + } } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 80b2b0942..7b1598ede 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -6,6 +6,9 @@ pub use builder::{Builder, BuilderContext}; mod message; pub use message::MessageBuilder; +mod record; +pub use record::RecordBuilder; + use super::wire::TruncationError; //----------- Message-aware building traits ---------------------------------- diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs new file mode 100644 index 000000000..873628bda --- /dev/null +++ b/src/new_base/build/record.rs @@ -0,0 +1,166 @@ +//! Building DNS records. + +use crate::new_base::{ + name::RevName, + wire::{AsBytes, TruncationError}, + Header, Message, RClass, RType, TTL, +}; + +use super::{BuildCommitted, BuildIntoMessage, Builder}; + +//----------- RecordBuilder -------------------------------------------------- + +/// A builder for a DNS record. +/// +/// This is used to incrementally build the data for a DNS record. It can be +/// constructed using [`MessageBuilder::build_answer()`] etc. +/// +/// [`MessageBuilder::build_answer()`]: super::MessageBuilder::build_answer() +pub struct RecordBuilder<'b> { + /// The underlying [`Builder`]. + /// + /// Its commit point lies at the beginning of the record. + inner: Builder<'b>, + + /// The position of the record data. + /// + /// This is an offset from the message contents. + start: usize, + + /// The section the record is a part of. + /// + /// The appropriate section count will be incremented on completion. + section: u8, +} + +//--- Initialization + +impl<'b> RecordBuilder<'b> { + /// Construct a [`RecordBuilder`] from raw parts. + /// + /// # Safety + /// + /// - `builder`, `start`, and `section` are paired together. + pub unsafe fn from_raw_parts( + builder: Builder<'b>, + start: usize, + section: u8, + ) -> Self { + Self { + inner: builder, + start, + section, + } + } + + /// Initialize a new [`RecordBuilder`]. + /// + /// A new record with the given name, type, and class will be created. + /// The returned builder can be used to add data for the record. + /// + /// The count for the specified section (1, 2, or 3, i.e. answers, + /// authorities, and additional records respectively) will be incremented + /// when the builder finishes successfully. + pub fn new( + mut builder: Builder<'b>, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + section: u8, + ) -> Result { + debug_assert_eq!(builder.appended(), &[] as &[u8]); + debug_assert!((1..4).contains(§ion)); + + assert!(builder + .header() + .counts + .as_array() + .iter() + .skip(1 + section as usize) + .all(|&c| c == 0)); + + // Build the record header. + rname.build_into_message(builder.delegate())?; + builder.append_bytes(rtype.as_bytes())?; + builder.append_bytes(rclass.as_bytes())?; + builder.append_bytes(ttl.as_bytes())?; + let start = builder.appended().len(); + + // Set up the builder. + Ok(Self { + inner: builder, + start, + section, + }) + } +} + +//--- Inspection + +impl<'b> RecordBuilder<'b> { + /// The message header. + pub fn header(&self) -> &Header { + self.inner.header() + } + + /// The message without this record. + pub fn message(&self) -> &Message { + self.inner.message() + } + + /// The record data appended thus far. + pub fn data(&self) -> &[u8] { + &self.inner.appended()[self.start..] + } + + /// Decompose this builder into raw parts. + /// + /// This returns the underlying builder, the offset of the record data in + /// the record, and the section number for this record (1, 2, or 3). The + /// builder can be recomposed with [`Self::from_raw_parts()`]. + pub fn into_raw_parts(self) -> (Builder<'b>, usize, u8) { + (self.inner, self.start, self.section) + } +} + +//--- Interaction + +impl RecordBuilder<'_> { + /// Finish the record. + /// + /// The respective section count will be incremented. The builder will be + /// consumed and the record will be committed. + pub fn finish(mut self) -> BuildCommitted { + // Increment the appropriate section count. + self.inner.header_mut().counts.as_array_mut() + [self.section as usize] += 1; + + self.inner.commit() + } + + /// Delegate to a new builder. + /// + /// Any content committed by the builder will be added as record data. + pub fn delegate(&mut self) -> Builder<'_> { + self.inner.delegate() + } + + /// Append some bytes. + /// + /// No name compression will be performed. + pub fn append_bytes( + &mut self, + bytes: &[u8], + ) -> Result<(), TruncationError> { + self.inner.append_bytes(bytes) + } + + /// Compress and append a domain name. + pub fn append_name( + &mut self, + name: &RevName, + ) -> Result<(), TruncationError> { + self.inner.append_name(name) + } +} From 7ef9b5b63c122caccb0d0fd20453cde976950fe9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 14:24:23 +0100 Subject: [PATCH 085/111] [new_base/name] Accept Clippy simplifications --- src/new_base/name/reversed.rs | 8 ++++---- src/new_base/name/unparsed.rs | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index d022352be..f33451f3b 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -332,14 +332,14 @@ fn parse_segment<'a>( buffer: &mut RevNameBuf, ) -> Result<(Option, &'a [u8]), ParseError> { loop { - match bytes { - &[0, ref rest @ ..] => { + match *bytes { + [0, ref rest @ ..] => { // Found the root, stop. buffer.prepend(&[0u8]); return Ok((None, rest)); } - &[l, ..] if l < 64 => { + [l, ..] if l < 64 => { // This looks like a regular label. if bytes.len() < 1 + l as usize { @@ -356,7 +356,7 @@ fn parse_segment<'a>( bytes = rest; } - &[hi, lo, ref rest @ ..] if hi >= 0xC0 => { + [hi, lo, ref rest @ ..] if hi >= 0xC0 => { let pointer = u16::from_be_bytes([hi, lo]); // NOTE: We don't verify the pointer here, that's left to diff --git a/src/new_base/name/unparsed.rs b/src/new_base/name/unparsed.rs index e437ddd83..828c92229 100644 --- a/src/new_base/name/unparsed.rs +++ b/src/new_base/name/unparsed.rs @@ -85,12 +85,12 @@ impl<'a> SplitFromMessage<'a> for &'a UnparsedName { let bytes = message.contents.get(start..).ok_or(ParseError)?; let mut offset = 0; let offset = loop { - match &bytes[offset..] { + match bytes[offset..] { // This is the root label. - &[0, ..] => break offset + 1, + [0, ..] => break offset + 1, // This looks like a regular label. - &[l, ref rest @ ..] if (1..64).contains(&l) => { + [l, ref rest @ ..] if (1..64).contains(&l) => { let length = l as usize; if rest.len() < length || offset + 2 + length > 255 { @@ -102,7 +102,7 @@ impl<'a> SplitFromMessage<'a> for &'a UnparsedName { } // This is a compression pointer. - &[hi, lo, ..] if hi >= 0xC0 => { + [hi, lo, ..] if hi >= 0xC0 => { let ptr = u16::from_be_bytes([hi, lo]); if usize::from(ptr - 0xC000) >= start { return Err(ParseError); @@ -128,12 +128,12 @@ impl<'a> ParseFromMessage<'a> for &'a UnparsedName { let bytes = message.contents.get(start..).ok_or(ParseError)?; let mut offset = 0; loop { - match &bytes[offset..] { + match bytes[offset..] { // This is the root label. - &[0] => break, + [0] => break, // This looks like a regular label. - &[l, ref rest @ ..] if (1..64).contains(&l) => { + [l, ref rest @ ..] if (1..64).contains(&l) => { let length = l as usize; if rest.len() < length || offset + 2 + length > 255 { @@ -145,7 +145,7 @@ impl<'a> ParseFromMessage<'a> for &'a UnparsedName { } // This is a compression pointer. - &[hi, lo] if hi >= 0xC0 => { + [hi, lo] if hi >= 0xC0 => { let ptr = u16::from_be_bytes([hi, lo]); if usize::from(ptr - 0xC000) >= start { return Err(ParseError); From bfa8c5cb37494043497a8f59b3e2b03304360b48 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 17:00:47 +0100 Subject: [PATCH 086/111] [new_base/build/record] Track record data size --- src/new_base/build/record.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index 873628bda..aac0857c3 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -85,6 +85,7 @@ impl<'b> RecordBuilder<'b> { builder.append_bytes(rtype.as_bytes())?; builder.append_bytes(rclass.as_bytes())?; builder.append_bytes(ttl.as_bytes())?; + builder.append_bytes(&0u16.to_be_bytes())?; let start = builder.appended().len(); // Set up the builder. @@ -136,6 +137,15 @@ impl RecordBuilder<'_> { self.inner.header_mut().counts.as_array_mut() [self.section as usize] += 1; + // Set the record data length. + let size = self.inner.appended().len() - self.start; + let size = u16::try_from(size) + .expect("Record data must be smaller than 64KiB"); + // SAFETY: The record data size is not part of a compressed name. + let appended = unsafe { self.inner.appended_mut() }; + appended[self.start - 2..self.start] + .copy_from_slice(&size.to_be_bytes()); + self.inner.commit() } From 07cf2deed00d5811b0c1b964d73da0d49727b07b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 16 Jan 2025 12:38:44 +0100 Subject: [PATCH 087/111] [new_base/parse] Remove 'ParseMessage' etc. These interfaces need to be redesigned to be more specific to important use cases. --- src/new_base/parse/message.rs | 49 ----------- src/new_base/parse/mod.rs | 9 -- src/new_base/parse/question.rs | 148 --------------------------------- src/new_base/parse/record.rs | 148 --------------------------------- 4 files changed, 354 deletions(-) delete mode 100644 src/new_base/parse/message.rs delete mode 100644 src/new_base/parse/question.rs delete mode 100644 src/new_base/parse/record.rs diff --git a/src/new_base/parse/message.rs b/src/new_base/parse/message.rs deleted file mode 100644 index 1c964588a..000000000 --- a/src/new_base/parse/message.rs +++ /dev/null @@ -1,49 +0,0 @@ -//! Parsing DNS messages. - -use core::ops::ControlFlow; - -use crate::new_base::{Header, UnparsedQuestion, UnparsedRecord}; - -/// A type that can be constructed by parsing a DNS message. -pub trait ParseMessage<'a>: Sized { - /// The type of visitors for incrementally building the output. - type Visitor: VisitMessagePart<'a>; - - /// The type of errors from converting a visitor into [`Self`]. - // TODO: Just use 'Visitor::Error'? - type Error; - - /// Construct a visitor, providing the message header. - fn make_visitor(header: &'a Header) - -> Result; - - /// Convert a visitor back to this type. - fn from_visitor(visitor: Self::Visitor) -> Result; -} - -/// A type that can visit the components of a DNS message. -pub trait VisitMessagePart<'a> { - /// The type of errors produced by visits. - type Error; - - /// Visit a component of the message. - fn visit( - &mut self, - component: MessagePart<'a>, - ) -> Result, Self::Error>; -} - -/// A component of a DNS message. -pub enum MessagePart<'a> { - /// A question. - Question(&'a UnparsedQuestion), - - /// An answer record. - Answer(&'a UnparsedRecord<'a>), - - /// An authority record. - Authority(&'a UnparsedRecord<'a>), - - /// An additional record. - Additional(&'a UnparsedRecord<'a>), -} diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 7e5a08d7f..e6d47f4f0 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,14 +1,5 @@ //! Parsing DNS messages from the wire format. -mod message; -pub use message::{MessagePart, ParseMessage, VisitMessagePart}; - -mod question; -pub use question::{ParseQuestion, ParseQuestions, VisitQuestion}; - -mod record; -pub use record::{ParseRecord, ParseRecords, VisitRecord}; - pub use super::wire::ParseError; use super::{ diff --git a/src/new_base/parse/question.rs b/src/new_base/parse/question.rs deleted file mode 100644 index 784cadc09..000000000 --- a/src/new_base/parse/question.rs +++ /dev/null @@ -1,148 +0,0 @@ -//! Parsing DNS questions. - -use core::{convert::Infallible, ops::ControlFlow}; - -#[cfg(feature = "std")] -use std::boxed::Box; -#[cfg(feature = "std")] -use std::vec::Vec; - -use crate::new_base::UnparsedQuestion; - -//----------- Trait definitions ---------------------------------------------- - -/// A type that can be constructed by parsing exactly one DNS question. -pub trait ParseQuestion: Sized { - /// The type of parse errors. - // TODO: Remove entirely? - type Error; - - /// Parse the given DNS question. - fn parse_question( - question: &UnparsedQuestion, - ) -> Result, Self::Error>; -} - -/// A type that can be constructed by parsing zero or more DNS questions. -pub trait ParseQuestions: Sized { - /// The type of visitors for incrementally building the output. - type Visitor: Default + VisitQuestion; - - /// The type of errors from converting a visitor into [`Self`]. - // TODO: Just use 'Visitor::Error'? Or remove entirely? - type Error; - - /// Convert a visitor back to this type. - fn from_visitor(visitor: Self::Visitor) -> Result; -} - -/// A type that can visit DNS questions. -pub trait VisitQuestion { - /// The type of errors produced by visits. - type Error; - - /// Visit a question. - fn visit_question( - &mut self, - question: &UnparsedQuestion, - ) -> Result, Self::Error>; -} - -//----------- Trait implementations ------------------------------------------ - -impl ParseQuestion for UnparsedQuestion { - type Error = Infallible; - - fn parse_question( - question: &UnparsedQuestion, - ) -> Result, Self::Error> { - Ok(ControlFlow::Break(question.clone())) - } -} - -//--- Impls for 'Option' - -impl ParseQuestion for Option { - type Error = T::Error; - - fn parse_question( - question: &UnparsedQuestion, - ) -> Result, Self::Error> { - Ok(match T::parse_question(question)? { - ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -impl ParseQuestions for Option { - type Visitor = Option; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor) - } -} - -impl VisitQuestion for Option { - type Error = T::Error; - - fn visit_question( - &mut self, - question: &UnparsedQuestion, - ) -> Result, Self::Error> { - if self.is_some() { - return Ok(ControlFlow::Continue(())); - } - - Ok(match T::parse_question(question)? { - ControlFlow::Break(elem) => { - *self = Some(elem); - ControlFlow::Break(()) - } - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -//--- Impls for 'Vec' - -#[cfg(feature = "std")] -impl ParseQuestions for Vec { - type Visitor = Vec; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor) - } -} - -#[cfg(feature = "std")] -impl VisitQuestion for Vec { - type Error = T::Error; - - fn visit_question( - &mut self, - question: &UnparsedQuestion, - ) -> Result, Self::Error> { - Ok(match T::parse_question(question)? { - ControlFlow::Break(elem) => { - self.push(elem); - ControlFlow::Break(()) - } - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -//--- Impls for 'Box<[T]>' - -#[cfg(feature = "std")] -impl ParseQuestions for Box<[T]> { - type Visitor = Vec; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor.into_boxed_slice()) - } -} diff --git a/src/new_base/parse/record.rs b/src/new_base/parse/record.rs deleted file mode 100644 index 75e98a36a..000000000 --- a/src/new_base/parse/record.rs +++ /dev/null @@ -1,148 +0,0 @@ -//! Parsing DNS records. - -use core::{convert::Infallible, ops::ControlFlow}; - -#[cfg(feature = "std")] -use std::boxed::Box; -#[cfg(feature = "std")] -use std::vec::Vec; - -use crate::new_base::UnparsedRecord; - -//----------- Trait definitions ---------------------------------------------- - -/// A type that can be constructed by parsing exactly one DNS record. -pub trait ParseRecord<'a>: Sized { - /// The type of parse errors. - // TODO: Remove entirely? - type Error; - - /// Parse the given DNS record. - fn parse_record( - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error>; -} - -/// A type that can be constructed by parsing zero or more DNS records. -pub trait ParseRecords<'a>: Sized { - /// The type of visitors for incrementally building the output. - type Visitor: Default + VisitRecord<'a>; - - /// The type of errors from converting a visitor into [`Self`]. - // TODO: Just use 'Visitor::Error'? Or remove entirely? - type Error; - - /// Convert a visitor back to this type. - fn from_visitor(visitor: Self::Visitor) -> Result; -} - -/// A type that can visit DNS records. -pub trait VisitRecord<'a> { - /// The type of errors produced by visits. - type Error; - - /// Visit a record. - fn visit_record( - &mut self, - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error>; -} - -//----------- Trait implementations ------------------------------------------ - -impl<'a> ParseRecord<'a> for UnparsedRecord<'a> { - type Error = Infallible; - - fn parse_record( - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error> { - Ok(ControlFlow::Break(record.clone())) - } -} - -//--- Impls for 'Option' - -impl<'a, T: ParseRecord<'a>> ParseRecord<'a> for Option { - type Error = T::Error; - - fn parse_record( - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error> { - Ok(match T::parse_record(record)? { - ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Option { - type Visitor = Option; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor) - } -} - -impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Option { - type Error = T::Error; - - fn visit_record( - &mut self, - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error> { - if self.is_some() { - return Ok(ControlFlow::Continue(())); - } - - Ok(match T::parse_record(record)? { - ControlFlow::Break(elem) => { - *self = Some(elem); - ControlFlow::Break(()) - } - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -//--- Impls for 'Vec' - -#[cfg(feature = "std")] -impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Vec { - type Visitor = Vec; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor) - } -} - -#[cfg(feature = "std")] -impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Vec { - type Error = T::Error; - - fn visit_record( - &mut self, - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error> { - Ok(match T::parse_record(record)? { - ControlFlow::Break(elem) => { - self.push(elem); - ControlFlow::Break(()) - } - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -//--- Impls for 'Box<[T]>' - -#[cfg(feature = "std")] -impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Box<[T]> { - type Visitor = Vec; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor.into_boxed_slice()) - } -} From 477252047ef16a64c25c1f5426d02db863eae53f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 16 Jan 2025 12:47:08 +0100 Subject: [PATCH 088/111] [new_base/build] Document 'BuildCommitted' thoroughly --- src/new_base/build/mod.rs | 54 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 7b1598ede..38ab501fe 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -38,13 +38,63 @@ impl BuildIntoMessage for [u8] { //----------- BuildResult ---------------------------------------------------- /// The result of building into a DNS message. +/// +/// This is used in [`BuildIntoMessage::build_into_message()`]. pub type BuildResult = Result; //----------- BuildCommitted ------------------------------------------------- /// The output of [`Builder::commit()`]. /// -/// This is a stub type to remind users to call [`Builder::commit()`] in all -/// success paths of building functions. +/// This is a simple marker type, produced by [`Builder::commit()`]. Certain +/// trait methods (e.g. [`BuildIntoMessage::build_into_message()`]) require it +/// in the return type, as a way to remind users to commit their builders. +/// +/// # Examples +/// +/// If `build_into_message()` simply returned a unit type, an example impl may +/// look like: +/// +/// ```compile_fail +/// # use domain::new_base::name::RevName; +/// # use domain::new_base::build::{BuildIntoMessage, Builder, BuildResult}; +/// # use domain::new_base::wire::AsBytes; +/// +/// struct Foo<'a>(&'a RevName, u8); +/// +/// impl BuildIntoMessage for Foo<'_> { +/// fn build_into_message( +/// &self, +/// mut builder: Builder<'_>, +/// ) -> BuildResult { +/// builder.append_name(self.0)?; +/// builder.append_bytes(self.1.as_bytes()); +/// Ok(()) +/// } +/// } +/// ``` +/// +/// This code is incorrect: since the appended content is not committed, the +/// builder will remove it when it is dropped (at the end of the function), +/// and so nothing gets written. Instead, users have to write: +/// +/// ``` +/// # use domain::new_base::name::RevName; +/// # use domain::new_base::build::{BuildIntoMessage, Builder, BuildResult}; +/// # use domain::new_base::wire::AsBytes; +/// +/// struct Foo<'a>(&'a RevName, u8); +/// +/// impl BuildIntoMessage for Foo<'_> { +/// fn build_into_message( +/// &self, +/// mut builder: Builder<'_>, +/// ) -> BuildResult { +/// builder.append_name(self.0)?; +/// builder.append_bytes(self.1.as_bytes()); +/// Ok(builder.commit()) +/// } +/// } +/// ``` #[derive(Debug)] pub struct BuildCommitted; From 554bb71462da71d1098b92659d84314e090dd8ca Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 14:55:10 +0100 Subject: [PATCH 089/111] [new_base/build/builder] Rewrite with a lot of documentation --- src/new_base/build/builder.rs | 430 ++++++++++++++++++++++++---------- src/new_base/message.rs | 18 ++ 2 files changed, 329 insertions(+), 119 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 274bbd7c6..16444349c 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -8,7 +8,7 @@ use core::{ use crate::new_base::{ name::RevName, - wire::{AsBytes, BuildBytes, ParseBytesByRef, TruncationError}, + wire::{BuildBytes, ParseBytesByRef, TruncationError}, Header, Message, }; @@ -16,7 +16,86 @@ use super::BuildCommitted; //----------- Builder -------------------------------------------------------- -/// A DNS message builder. +/// A DNS wire format serializer. +/// +/// This can be used to write arbitrary bytes and (compressed) domain names to +/// a buffer containing a DNS message. It is a low-level interface, providing +/// the foundations for high-level builder types. +/// +/// In order to build a regular DNS message, users would typically look to +/// [`MessageBuilder`](super::MessageBuilder). This offers the high-level +/// interface (with methods to append questions and records) that most users +/// need. +/// +/// # Committing and Delegation +/// +/// [`Builder`] provides an "atomic" interface: if a function fails while +/// building a DNS message using a [`Builder`], any partial content added by +/// the [`Builder`] will be reverted. The content of a [`Builder`] is only +/// confirmed when [`Builder::commit()`] is called. +/// +/// It is useful to first describe what "building functions" look like. While +/// they may take additional arguments, their signatures are usually: +/// +/// ```no_run +/// # use domain::new_base::build::{Builder, BuildResult}; +/// +/// fn foo(mut builder: Builder<'_>) -> BuildResult { +/// // Append to the message using 'builder'. +/// +/// // Commit all appended content and return successfully. +/// Ok(builder.commit()) +/// } +/// ``` +/// +/// Note that the builder is taken by value; if an error occurs, and the +/// function returns early, `builder` will be dropped, and its drop code will +/// revert all uncommitted changes. However, if building is successful, the +/// appended content is committed, and so will not be reverted. +/// +/// If `foo` were to call another function with the same signature, it would +/// need to create a new [`Builder`] to pass in by value. This [`Builder`] +/// should refer to the same message buffer, but should have not report any +/// uncommitted content (so that only the content added by the called function +/// will be reverted on failure). For this, we have [`delegate()`]. +/// +/// [`delegate()`]: Self::delegate() +/// +/// For example: +/// +/// ``` +/// # use domain::new_base::build::{Builder, BuildResult, BuilderContext}; +/// +/// /// A build function with the conventional type signature. +/// fn foo(mut builder: Builder<'_>) -> BuildResult { +/// // Content added by the parent builder is considered committed. +/// assert_eq!(builder.committed(), b"hi! "); +/// +/// // Append some content to the builder. +/// builder.append_bytes(b"foo!")?; +/// +/// // Try appending a very long string, which can't fit. +/// builder.append_bytes(b"helloworldthisiswaytoobig")?; +/// +/// Ok(builder.commit()) +/// } +/// +/// // Construct a builder for a particular buffer. +/// let mut buffer = [0u8; 20]; +/// let mut context = BuilderContext::default(); +/// let mut builder = Builder::new(&mut buffer, &mut context); +/// +/// // Try appending some content to the builder. +/// builder.append_bytes(b"hi! ").unwrap(); +/// assert_eq!(builder.appended(), b"hi! "); +/// +/// // Try calling 'foo' -- note that it will fail. +/// // Note that we delegated the builder. +/// foo(builder.delegate()).unwrap_err(); +/// +/// // No partial content was written. +/// assert_eq!(builder.appended(), b"hi! "); +/// ``` pub struct Builder<'b> { /// The message being built. /// @@ -41,22 +120,74 @@ pub struct Builder<'b> { commit: usize, } -//--- Initialization - +/// # Initialization +/// +/// In order to begin building a DNS message: +/// +/// ``` +/// # use domain::new_base::build::{Builder, BuilderContext}; +/// +/// // Allocate a slice of 'u8's somewhere. +/// let mut buffer = [0u8; 20]; +/// +/// // Obtain a builder context. +/// // +/// // The value doesn't matter, it will be overwritten. +/// let mut context = BuilderContext::default(); +/// +/// // Construct the actual 'Builder'. +/// let builder = Builder::new(&mut buffer, &mut context); +/// +/// assert!(builder.committed().is_empty()); +/// assert!(builder.appended().is_empty()); +/// ``` impl<'b> Builder<'b> { + /// Create a [`Builder`] for a new, empty DNS message. + /// + /// The message header is left uninitialized. Use [`Self::header_mut()`] + /// to initialize it. The message contents are completely empty. + /// + /// The provided builder context will be overwritten with a default state. + /// + /// # Panics + /// + /// Panics if the buffer is less than 12 bytes long (which is the minimum + /// possible size for a DNS message). + pub fn new( + buffer: &'b mut [u8], + context: &'b mut BuilderContext, + ) -> Self { + let message = Message::parse_bytes_by_mut(buffer) + .expect("The buffure must be at least 12 bytes in size"); + context.size = 0; + + // SAFETY: 'message' and 'context' are now consistent. + unsafe { Self::from_raw_parts(message.into(), context, 0) } + } + /// Construct a [`Builder`] from raw parts. /// + /// The provided components must originate from [`into_raw_parts()`], and + /// none of the components can be modified since they were extracted. + /// + /// [`into_raw_parts()`]: Self::into_raw_parts() + /// + /// This method is useful when overcoming limitations in lifetimes or + /// borrow checking, or when a builder has to be constructed from another + /// with specific characteristics. + /// /// # Safety /// + /// The expression `from_raw_parts(message, context, commit)` is sound if + /// and only if all of the following conditions are satisfied: + /// /// - `message` is a valid reference for the lifetime `'b`. /// - `message.header` is mutably borrowed for `'b`. /// - `message.contents[..commit]` is immutably borrowed for `'b`. /// - `message.contents[commit..]` is mutably borrowed for `'b`. /// - /// - `message` and `context` are paired together. - /// - /// - `commit` is at most `context.size()`, which is at most - /// `context.max_size()`. + /// - `message` and `context` originate from the same builder. + /// - `commit <= context.size() <= message.contents.len()`. pub unsafe fn from_raw_parts( message: NonNull, context: &'b mut BuilderContext, @@ -69,53 +200,82 @@ impl<'b> Builder<'b> { commit, } } - - /// Initialize an empty [`Builder`]. - /// - /// The message header is left uninitialized. Use [`Self::header_mut()`] - /// to initialize it. - /// - /// # Panics - /// - /// Panics if the buffer is less than 12 bytes long (which is the minimum - /// possible size for a DNS message). - pub fn new( - buffer: &'b mut [u8], - context: &'b mut BuilderContext, - ) -> Self { - assert!(buffer.len() >= 12); - let message = Message::parse_bytes_by_mut(buffer) - .expect("A 'Message' can fit in 12 bytes"); - context.size = 0; - context.max_size = message.contents.len(); - - // SAFETY: 'message' and 'context' are now consistent. - unsafe { Self::from_raw_parts(message.into(), context, 0) } - } } -//--- Inspection - +/// # Inspection +/// +/// A [`Builder`] references a message buffer to write into. That buffer is +/// broken down into the following segments: +/// +/// ```text +/// name | position +/// --------------+--------- +/// header | +/// committed | 0 .. commit +/// appended | commit .. size +/// uninitialized | size .. limit +/// inaccessible | limit .. +/// ``` +/// +/// The DNS message header can be modified at any time. It is made available +/// through [`header()`] and [`header_mut()`]. In general, it is inadvisable +/// to change the section counts arbitrarily (although it will not cause +/// undefined behaviour). +/// +/// [`header()`]: Self::header() +/// [`header_mut()`]: Self::header_mut() +/// +/// The committed content of the builder is immutable, and is available to +/// reference, through [`committed()`], for the lifetime `'b`. +/// +/// [`committed()`]: Self::committed() +/// +/// The appended content of the builder is made available via [`appended()`]. +/// It is content that has been added by this builder, but that has not yet +/// been committed. When the [`Builder`] is dropped, this content is removed +/// (it becomes uninitialized). Appended content can be modified, but any +/// compressed names within it have to be handled with great care; they can +/// only be modified by removing them entirely (by rewinding the builder, +/// using [`rewind()`]) and building them again. When compressed names are +/// guaranteed to not be modified, [`appended_mut()`] can be used. +/// +/// [`appended()`]: Self::appended() +/// [`rewind()`]: Self::rewind() +/// [`appended_mut()`]: Self::appended_mut() +/// +/// The uninitialized space in the builder will be written to when appending +/// new content. It can be accessed directly, in case that is more efficient +/// for building, using [`uninitialized()`]. [`mark_appended()`] can be used +/// to specify how many bytes were initialized. +/// +/// [`uninitialized()`]: Self::uninitialized() +/// [`mark_appended()`]: Self::mark_appended() +/// +/// The inaccessible space of a builder cannot be written to. While it exists +/// in the underlying message buffer, it has been made inaccessible so that +/// the built message fits within certain size constraints. A message's size +/// can be limited using [`limit_to()`], but this only applies to the current +/// builder (and its delegates); parent builders are unaffected by it. +/// +/// [`limit_to()`]: Self::limit_to() impl<'b> Builder<'b> { - /// The message header. - /// - /// The header can be modified by the builder, and so is only available - /// for a short lifetime. Note that it implements [`Copy`]. + /// The header of the DNS message. pub fn header(&self) -> &Header { // SAFETY: 'message.header' is mutably borrowed by 'self'. unsafe { &(*self.message.as_ptr()).header } } - /// Mutable access to the message header. + /// The header of the DNS message, mutably. + /// + /// It is possible to modify the section counts arbitrarily through this + /// method; while doing so cannot cause undefined behaviour, it is not + /// recommended. pub fn header_mut(&mut self) -> &mut Header { // SAFETY: 'message.header' is mutably borrowed by 'self'. unsafe { &mut (*self.message.as_ptr()).header } } /// Committed message contents. - /// - /// The message contents are available for the lifetime `'b`; the builder - /// cannot be used to modify them since they have been committed. pub fn committed(&self) -> &'b [u8] { // SAFETY: 'message.contents[..commit]' is immutably borrowed by // 'self'. @@ -146,38 +306,31 @@ impl<'b> Builder<'b> { /// Uninitialized space in the message buffer. /// - /// This can be filled manually, then marked as initialized using - /// [`Self::mark_appended()`]. + /// When the first `n` bytes of the returned buffer are initialized, and + /// should be treated as appended content in the message, call + /// [`self.mark_appended(n)`](Self::mark_appended()). pub fn uninitialized(&mut self) -> &mut [u8] { // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. - let range = self.context.size..self.context.max_size; - unsafe { &mut (*self.message.as_ptr()).contents[range] } + unsafe { &mut (*self.message.as_ptr()).contents[self.context.size..] } } /// The message with all committed contents. /// /// The header of the message can be modified by the builder, so the /// returned reference has a short lifetime. The message contents can be - /// borrowed for a longer lifetime -- see [`Self::committed()`]. + /// borrowed for a longer lifetime -- see [`committed()`]. The message + /// does not include content that has been appended but not committed. + /// + /// [`committed()`]: Self::committed() pub fn message(&self) -> &Message { // SAFETY: All of 'message' can be immutably borrowed by 'self'. - let message = unsafe { &*self.message.as_ptr() }; - let message = &message.as_bytes()[..12 + self.commit]; - Message::parse_bytes_by_ref(message) - .expect("'message' represents a valid 'Message'") + unsafe { self.message.as_ref() }.slice_to(self.commit) } /// The message including any uncommitted contents. - /// - /// The header of the message can be modified by the builder, so the - /// returned reference has a short lifetime. The message contents can be - /// borrowed for a longer lifetime -- see [`Self::committed()`]. pub fn cur_message(&self) -> &Message { // SAFETY: All of 'message' can be immutably borrowed by 'self'. - let message = unsafe { &*self.message.as_ptr() }; - let message = &message.as_bytes()[..12 + self.context.size]; - Message::parse_bytes_by_ref(message) - .expect("'message' represents a valid 'Message'") + unsafe { self.message.as_ref() }.slice_to(self.context.size) } /// A pointer to the message, including any uncommitted contents. @@ -185,8 +338,11 @@ impl<'b> Builder<'b> { /// The first `commit` bytes of the message contents (also provided by /// [`Self::committed()`]) are immutably borrowed for the lifetime `'b`. /// The remainder of the message is initialized and borrowed by `self`. - pub fn cur_message_ptr(&self) -> NonNull { - self.cur_message().into() + pub fn cur_message_ptr(&mut self) -> NonNull { + let message = self.message.as_ptr(); + let size = self.context.size; + let message = unsafe { Message::ptr_slice_to(message, size) }; + unsafe { NonNull::new_unchecked(message) } } /// The builder context. @@ -194,6 +350,29 @@ impl<'b> Builder<'b> { &*self.context } + /// The start point of this builder. + /// + /// This is the offset into the message contents at which this builder was + /// initialized. The content before this point has been committed and is + /// immutable. The builder can be rewound up to this point. + pub fn start(&self) -> usize { + self.commit + } + + /// The size limit of this builder. + /// + /// This is the maximum size the message contents can grow to; beyond it, + /// [`TruncationError`]s will occur. The limit can be tightened using + /// [`limit_to()`](Self::limit_to()). + pub fn max_size(&self) -> usize { + // SAFETY: 'Message' ends with a slice DST, and so references to it + // hold the length of that slice; we can cast it to another slice type + // and the pointer representation is unchanged. By using a slice type + // of ZST elements, aliasing is impossible, and it can be dereferenced + // safely. + unsafe { &*(self.message.as_ptr() as *mut [()]) }.len() + } + /// Decompose this builder into raw parts. /// /// This returns three components: @@ -221,24 +400,74 @@ impl<'b> Builder<'b> { } } -//--- Interaction - +/// # Interaction +/// +/// There are several ways to build up a DNS message using a [`Builder`]. +/// +/// When directly adding content, use [`append_bytes()`] or [`append_name()`]. +/// The former will add the bytes as-is, while the latter will compress domain +/// names. +/// +/// [`append_bytes()`]: Self::append_bytes() +/// [`append_name()`]: Self::append_name() +/// +/// When delegating to another builder method, use [`delegate()`]. This will +/// construct a new [`Builder`] that borrows from the current one. When the +/// method returns, the content it has committed will be registered as content +/// appended (but not committed) by the outer builder. If the method fails, +/// any content it tried to add will be removed automatically, and the outer +/// builder will be left unaffected. +/// +/// [`delegate()`]: Self::delegate() +/// +/// After all data is appended, call [`commit()`]. This will return a marker +/// type, [`BuildCommitted`], that may need to be returned to the caller. +/// +/// [`commit()`]: Self::commit() +/// +/// Some lower-level building methods are also available in the interest of +/// efficiency. Use [`append_with()`] if the amount of data to be written is +/// known upfront; it takes a closure to fill that space in the buffer. The +/// most general and efficient technique is to write into [`uninitialized()`] +/// and to mark the number of initialized bytes using [`mark_appended()`]. +/// +/// [`append_with()`]: Self::append_with() +/// [`uninitialized()`]: Self::uninitialized() +/// [`mark_appended()`]: Self::mark_appended() impl Builder<'_> { - /// Rewind the builder, removing all committed content. + /// Rewind the builder, removing all uncommitted content. pub fn rewind(&mut self) { self.context.size = self.commit; } - /// Commit all appended content. + /// Commit the changes made by this builder. /// /// For convenience, a unit type [`BuildCommitted`] is returned; it is /// used as the return type of build functions to remind users to call /// this method on success paths. - pub fn commit(&mut self) -> BuildCommitted { + pub fn commit(mut self) -> BuildCommitted { + // Update 'commit' so that the drop glue is a no-op. self.commit = self.context.size; BuildCommitted } + /// Limit this builder to the given size. + /// + /// This builder, and all its delegates, will not allow the message + /// contents (i.e. excluding the 12-byte message header) to exceed the + /// specified size in bytes. If the message has already crossed that + /// limit, a [`TruncationError`] is returned. + pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { + if self.context.size <= size { + let message = self.message.as_ptr(); + let message = unsafe { Message::ptr_slice_to(message, size) }; + self.message = unsafe { NonNull::new_unchecked(message) }; + Ok(()) + } else { + Err(TruncationError) + } + } + /// Mark bytes in the buffer as initialized. /// /// The given number of bytes from the beginning of @@ -250,7 +479,7 @@ impl Builder<'_> { /// Panics if the uninitialized buffer is smaller than the given number of /// initialized bytes. pub fn mark_appended(&mut self, amount: usize) { - assert!(self.context.max_size - self.context.size >= amount); + assert!(self.max_size() - self.context.size >= amount); self.context.size += amount; } @@ -265,30 +494,6 @@ impl Builder<'_> { } } - /// Limit the total message size. - /// - /// The message will not be allowed to exceed the given size, in bytes. - /// Only the message header and contents are counted; the enclosing UDP - /// or TCP packet size is not considered. If the message already exceeds - /// this size, a [`TruncationError`] is returned. - /// - /// This size will apply to all builders for this message (including those - /// that delegated to `self`). It will not be automatically revoked if - /// message building fails. - /// - /// # Panics - /// - /// Panics if the given size is less than 12 bytes. - pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { - assert!(size >= 12); - if self.context.size <= size - 12 { - self.context.max_size = size - 12; - Ok(()) - } else { - Err(TruncationError) - } - } - /// Append data of a known size using a closure. /// /// All the requested bytes must be initialized. If not enough free space @@ -336,40 +541,27 @@ impl Drop for Builder<'_> { } } +//--- Send, Sync + +// SAFETY: The parts of the referenced message that can be accessed mutably +// are not accessible by any reference other than `self`. +unsafe impl Send for Builder<'_> {} + +// SAFETY: Only parts of the referenced message that are borrowed immutably +// can be accessed through an immutable reference to `self`. +unsafe impl Sync for Builder<'_> {} + //----------- BuilderContext ------------------------------------------------- /// Context for building a DNS message. -#[derive(Clone, Debug)] +/// +/// This type holds auxiliary information necessary for building DNS messages, +/// e.g. name compression state. To construct it, call [`default()`]. +/// +/// [`default()`]: Self::default() +#[derive(Clone, Debug, Default)] pub struct BuilderContext { // TODO: Name compression. /// The current size of the message contents. size: usize, - - /// The maximum size of the message contents. - max_size: usize, -} - -//--- Inspection - -impl BuilderContext { - /// The size of the message contents. - pub fn size(&self) -> usize { - self.size - } - - /// The maximum size of the message contents. - pub fn max_size(&self) -> usize { - self.max_size - } -} - -//--- Default - -impl Default for BuilderContext { - fn default() -> Self { - Self { - size: 0, - max_size: 65535 - core::mem::size_of::
(), - } - } } diff --git a/src/new_base/message.rs b/src/new_base/message.rs index d900b4a1c..27e4cde88 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -60,6 +60,24 @@ impl Message { Self::parse_bytes_by_mut(bytes) .expect("A 12-or-more byte string is a valid 'Message'") } + + /// Truncate the contents of this message to the given size, by pointer. + /// + /// The returned value will have a `contents` field of the given size. + /// + /// # Safety + /// + /// This method uses `pointer::offset()`: `self` must be "derived from a + /// pointer to some allocated object". There must be at least 12 bytes + /// between `self` and the end of that allocated object. A reference to + /// `Message` will always result in a pointer satisfying this. + pub unsafe fn ptr_slice_to(this: *mut Message, size: usize) -> *mut Self { + let bytes = unsafe { core::ptr::addr_of_mut!((*this).contents) }; + let len = unsafe { &*(bytes as *mut [()]) }.len(); + debug_assert!(size <= len); + core::ptr::slice_from_raw_parts_mut(this.cast::(), size) + as *mut Self + } } //----------- Header --------------------------------------------------------- From ef702e300be3bb7268140107e3269739fd353cfd Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 17:41:26 +0100 Subject: [PATCH 090/111] [new_base/build] Add module documentation --- src/new_base/build/mod.rs | 88 +++++++++++++++++++++++++++++++++++++++ src/new_base/question.rs | 56 +++++++++++++++++++++++++ src/new_base/wire/ints.rs | 5 +++ 3 files changed, 149 insertions(+) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 38ab501fe..e723ef521 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -1,4 +1,92 @@ //! Building DNS messages in the wire format. +//! +//! The [`wire`](super::wire) module provides basic serialization capability, +//! but it is not specialized to DNS messages. This module provides that +//! specialization within an ergonomic interface. +//! +//! # The High-Level Interface +//! +//! The core of the high-level interface is [`MessageBuilder`]. It provides +//! the most intuitive methods for appending whole questions and records. +//! +//! ``` +//! use domain::new_base::{Header, HeaderFlags, Question, QType, QClass}; +//! use domain::new_base::build::{BuilderContext, MessageBuilder, BuildIntoMessage}; +//! use domain::new_base::name::RevName; +//! use domain::new_base::wire::U16; +//! +//! // Initialize a DNS message builder. +//! let mut buffer = [0u8; 512]; +//! let mut context = BuilderContext::default(); +//! let mut builder = MessageBuilder::new(&mut buffer, &mut context); +//! +//! // Initialize the message header. +//! let header = builder.header_mut(); +//! *builder.header_mut() = Header { +//! // Select a randomized ID here. +//! id: U16::new(1234), +//! // A recursive query for authoritative data. +//! flags: HeaderFlags::default() +//! .query(0) +//! .set_authoritative(true) +//! .request_recursion(true), +//! counts: Default::default(), +//! }; +//! +//! // Add a question for an A record. +//! // TODO: Use a more ergonomic way to make a name. +//! let name = b"\x00\x03org\x07example\x03www"; +//! let name = unsafe { RevName::from_bytes_unchecked(name) }; +//! let question = Question { +//! qname: name, +//! qtype: QType::A, +//! qclass: QClass::IN, +//! }; +//! builder.append_question(&question).unwrap(); +//! +//! // Use the built message. +//! let message = builder.message(); +//! # let _ = message; +//! ``` +//! +//! # The Low-Level Interface +//! +//! [`Builder`] is a powerful low-level interface that can be used to build +//! DNS messages. It implements atomic building and name compression, and is +//! the foundation of [`MessageBuilder`]. +//! +//! The [`Builder`] interface does not know about questions and records; it is +//! only capable of appending simple bytes and compressing domain names. Its +//! access to the message buffer is limited; it can only append, modify, or +//! truncate the message up to a certain point (all data before that point is +//! immutable). Special attention is given to the message header, as it can +//! be modified at any point in the message building process. +//! +//! ``` +//! use domain::new_base::build::{BuilderContext, Builder, BuildIntoMessage}; +//! use domain::new_rdata::A; +//! +//! // Construct a builder for a particular buffer. +//! let mut buffer = [0u8; 20]; +//! let mut context = BuilderContext::default(); +//! let mut builder = Builder::new(&mut buffer, &mut context); +//! +//! // Try appending some raw bytes to the builder. +//! builder.append_bytes(b"hi! ").unwrap(); +//! assert_eq!(builder.appended(), b"hi! "); +//! +//! // Try appending some structured content to the builder. +//! A::from(std::net::Ipv4Addr::new(127, 0, 0, 1)) +//! .build_into_message(builder.delegate()) +//! .unwrap(); +//! assert_eq!(builder.appended(), b"hi! \x7F\x00\x00\x01"); +//! +//! // Finish using the builder. +//! builder.commit(); +//! +//! // Note: the first 12 bytes hold the message header. +//! assert_eq!(&buffer[12..20], b"hi! \x7F\x00\x00\x01"); +//! ``` mod builder; pub use builder::{Builder, BuilderContext}; diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 720d46e14..e4602a4b6 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -115,6 +115,46 @@ pub struct QType { pub code: U16, } +//--- Associated Constants + +impl QType { + const fn new(value: u16) -> Self { + Self { + code: U16::new(value), + } + } + + /// The type of an [`A`](crate::new_rdata::A) record. + pub const A: Self = Self::new(1); + + /// The type of an [`Ns`](crate::new_rdata::Ns) record. + pub const NS: Self = Self::new(2); + + /// The type of a [`CName`](crate::new_rdata::CName) record. + pub const CNAME: Self = Self::new(5); + + /// The type of an [`Soa`](crate::new_rdata::Soa) record. + pub const SOA: Self = Self::new(6); + + /// The type of a [`Wks`](crate::new_rdata::Wks) record. + pub const WKS: Self = Self::new(11); + + /// The type of a [`Ptr`](crate::new_rdata::Ptr) record. + pub const PTR: Self = Self::new(12); + + /// The type of a [`HInfo`](crate::new_rdata::HInfo) record. + pub const HINFO: Self = Self::new(13); + + /// The type of a [`Mx`](crate::new_rdata::Mx) record. + pub const MX: Self = Self::new(15); + + /// The type of a [`Txt`](crate::new_rdata::Txt) record. + pub const TXT: Self = Self::new(16); + + /// The type of an [`Aaaa`](crate::new_rdata::Aaaa) record. + pub const AAAA: Self = Self::new(28); +} + //----------- QClass --------------------------------------------------------- /// The class of a question. @@ -139,3 +179,19 @@ pub struct QClass { /// The class code. pub code: U16, } + +//--- Associated Constants + +impl QClass { + const fn new(value: u16) -> Self { + Self { + code: U16::new(value), + } + } + + /// The Internet class. + pub const IN: Self = Self::new(1); + + /// The CHAOS class. + pub const CH: Self = Self::new(3); +} diff --git a/src/new_base/wire/ints.rs b/src/new_base/wire/ints.rs index 3d11f45e4..15834a55f 100644 --- a/src/new_base/wire/ints.rs +++ b/src/new_base/wire/ints.rs @@ -51,6 +51,11 @@ macro_rules! define_int { pub const fn get(self) -> $base { <$base>::from_be_bytes(self.0) } + + /// Overwrite this value with an integer. + pub const fn set(&mut self, value: $base) { + *self = Self::new(value) + } } impl From<$base> for $name { From 0e7346cead2e6317a6ff4dd63876cd13f0774266 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 17:44:31 +0100 Subject: [PATCH 091/111] [new_base/parse] Add a bit of module documentation --- src/new_base/parse/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index e6d47f4f0..6744c15a6 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,4 +1,9 @@ //! Parsing DNS messages from the wire format. +//! +//! This module provides [`ParseFromMessage`] and [`SplitFromMessage`], which +//! are specializations of [`ParseBytes`] and [`SplitBytes`] to DNS messages. +//! When parsing data within a DNS message, these traits allow access to all +//! preceding bytes in the message so that compressed names can be resolved. pub use super::wire::ParseError; From 967d1d5a84eb50becf4a64f3711e8a58ec675070 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 17:45:10 +0100 Subject: [PATCH 092/111] [new_base/parse] Add missing doc links --- src/new_base/parse/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 6744c15a6..cbe949681 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -4,6 +4,9 @@ //! are specializations of [`ParseBytes`] and [`SplitBytes`] to DNS messages. //! When parsing data within a DNS message, these traits allow access to all //! preceding bytes in the message so that compressed names can be resolved. +//! +//! [`ParseBytes`]: super::wire::ParseBytes +//! [`SplitBytes`]: super::wire::SplitBytes pub use super::wire::ParseError; From 4fa09b25db12330c5e02da0a842696cec36b04b6 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 17:51:39 +0100 Subject: [PATCH 093/111] [new_base/wire/ints] Make 'set()' non-const --- src/new_base/wire/ints.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/new_base/wire/ints.rs b/src/new_base/wire/ints.rs index 15834a55f..acd2990e2 100644 --- a/src/new_base/wire/ints.rs +++ b/src/new_base/wire/ints.rs @@ -53,7 +53,8 @@ macro_rules! define_int { } /// Overwrite this value with an integer. - pub const fn set(&mut self, value: $base) { + // TODO: Make 'const' at MSRV 1.83.0. + pub fn set(&mut self, value: $base) { *self = Self::new(value) } } From 9e44e48c61740ae9dc4c92689a2782fe331de06d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 22 Jan 2025 14:40:26 +0100 Subject: [PATCH 094/111] [new_base/wire/parse] Fix documentation typo See: --- src/new_base/wire/parse.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/new_base/wire/parse.rs b/src/new_base/wire/parse.rs index 3ee5d44a1..2a8da6642 100644 --- a/src/new_base/wire/parse.rs +++ b/src/new_base/wire/parse.rs @@ -89,8 +89,7 @@ impl<'a> SplitBytes<'a> for u8 { /// Deriving [`SplitBytes`] automatically. /// /// [`SplitBytes`] can be derived on `struct`s (not `enum`s or `union`s). All -/// fields except the last must implement [`SplitBytes`], while the last field -/// only needs to implement [`SplitBytes`]. +/// fields must implement [`SplitBytes`]. /// /// Here's a simple example: /// From 9af8ea199cb2435919961338525a06d8e62e44ed Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 22 Jan 2025 14:45:27 +0100 Subject: [PATCH 095/111] [new_edns/cookie] Rename 'CookieRequest' to 'ClientCookie' Also fixes typo in field name 'reversed' ('reserved') of 'Cookie'. See: See: --- src/new_edns/cookie.rs | 30 +++++++++++++++--------------- src/new_edns/mod.rs | 20 ++++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 36d96a4f4..77c810be5 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -20,7 +20,7 @@ use crate::new_base::Serial; #[cfg(all(feature = "std", feature = "siphasher"))] use crate::new_base::wire::{AsBytes, TruncationError}; -//----------- CookieRequest -------------------------------------------------- +//----------- ClientCookie --------------------------------------------------- /// A request for a DNS cookie. #[derive( @@ -37,15 +37,15 @@ use crate::new_base::wire::{AsBytes, TruncationError}; SplitBytesByRef, )] #[repr(transparent)] -pub struct CookieRequest { +pub struct ClientCookie { /// The octets of the request. pub octets: [u8; 8], } //--- Construction -impl CookieRequest { - /// Construct a random [`CookieRequest`]. +impl ClientCookie { + /// Construct a random [`ClientCookie`]. #[cfg(feature = "rand")] pub fn random() -> Self { rand::random::<[u8; 8]>().into() @@ -54,7 +54,7 @@ impl CookieRequest { //--- Interaction -impl CookieRequest { +impl ClientCookie { /// Build a [`Cookie`] in response to this request. /// /// A 24-byte version-1 interoperable cookie will be generated and written @@ -101,27 +101,27 @@ impl CookieRequest { //--- Conversion to and from octets -impl From<[u8; 8]> for CookieRequest { +impl From<[u8; 8]> for ClientCookie { fn from(value: [u8; 8]) -> Self { Self { octets: value } } } -impl From for [u8; 8] { - fn from(value: CookieRequest) -> Self { +impl From for [u8; 8] { + fn from(value: ClientCookie) -> Self { value.octets } } //--- Formatting -impl fmt::Debug for CookieRequest { +impl fmt::Debug for ClientCookie { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "CookieRequest({})", self) + write!(f, "ClientCookie({})", self) } } -impl fmt::Display for CookieRequest { +impl fmt::Display for ClientCookie { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:016X}", u64::from_be_bytes(self.octets)) } @@ -135,14 +135,14 @@ impl fmt::Display for CookieRequest { )] #[repr(C)] pub struct Cookie { - /// The request for this cookie. - request: CookieRequest, + /// The client's request for this cookie. + request: ClientCookie, /// The version number of this cookie. version: u8, /// Reserved bytes in the cookie format. - reversed: [u8; 3], + reserved: [u8; 3], /// When this cookie was made. timestamp: Serial, @@ -155,7 +155,7 @@ pub struct Cookie { impl Cookie { /// The underlying cookie request. - pub fn request(&self) -> &CookieRequest { + pub fn request(&self) -> &ClientCookie { &self.request } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 152cd5dae..8bc3c55b6 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -21,7 +21,7 @@ use crate::{ //----------- EDNS option modules -------------------------------------------- mod cookie; -pub use cookie::{Cookie, CookieRequest}; +pub use cookie::{ClientCookie, Cookie}; mod ext_err; pub use ext_err::{ExtError, ExtErrorCode}; @@ -212,10 +212,10 @@ impl fmt::Debug for EdnsFlags { #[derive(Debug)] #[non_exhaustive] pub enum EdnsOption<'b> { - /// A request for a DNS cookie. - CookieRequest(&'b CookieRequest), + /// A client's request for a DNS cookie. + ClientCookie(&'b ClientCookie), - /// A DNS cookie. + /// A server-provided DNS cookie. Cookie(&'b Cookie), /// An extended DNS error. @@ -231,7 +231,7 @@ impl EdnsOption<'_> { /// The code for this option. pub fn code(&self) -> OptionCode { match self { - Self::CookieRequest(_) => OptionCode::COOKIE, + Self::ClientCookie(_) => OptionCode::COOKIE, Self::Cookie(_) => OptionCode::COOKIE, Self::ExtError(_) => OptionCode::EXT_ERROR, Self::Unknown(code, _) => *code, @@ -248,8 +248,8 @@ impl<'b> ParseBytes<'b> for EdnsOption<'b> { match code { OptionCode::COOKIE => match data.len() { - 8 => CookieRequest::parse_bytes_by_ref(data) - .map(Self::CookieRequest), + 8 => ClientCookie::parse_bytes_by_ref(data) + .map(Self::ClientCookie), 16..=40 => Cookie::parse_bytes_by_ref(data).map(Self::Cookie), _ => Err(ParseError), }, @@ -273,8 +273,8 @@ impl<'b> SplitBytes<'b> for EdnsOption<'b> { let this = match code { OptionCode::COOKIE => match data.len() { - 8 => <&CookieRequest>::parse_bytes(data) - .map(Self::CookieRequest)?, + 8 => <&ClientCookie>::parse_bytes(data) + .map(Self::ClientCookie)?, 16..=40 => <&Cookie>::parse_bytes(data).map(Self::Cookie)?, _ => return Err(ParseError), }, @@ -301,7 +301,7 @@ impl BuildBytes for EdnsOption<'_> { bytes = self.code().build_bytes(bytes)?; let data = match self { - Self::CookieRequest(this) => this.as_bytes(), + Self::ClientCookie(this) => this.as_bytes(), Self::Cookie(this) => this.as_bytes(), Self::ExtError(this) => this.as_bytes(), Self::Unknown(_, this) => this.as_bytes(), From c5f1552c5ee4745b68b8f5fd68bfbcce3bf64b4f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 22 Jan 2025 15:22:01 +0100 Subject: [PATCH 096/111] [new_base/parse] Refactor '{Parse,Split}FromMessage' - The traits have been renamed to '{Parse,Split}MessageBytes' for consistency with '{Parse,Split}Bytes'. - The traits now consume 'message.contents' instead of the whole 'message', allowing them to be used in more contexts (e.g. the upcoming overhauled builder types). --- src/new_base/charstr.rs | 26 ++++------ src/new_base/name/reversed.rs | 17 +++--- src/new_base/name/unparsed.rs | 19 ++++--- src/new_base/parse/mod.rs | 63 +++++++++++------------ src/new_base/question.rs | 31 ++++++----- src/new_base/record.rs | 50 +++++++++--------- src/new_base/wire/size_prefixed.rs | 30 +++++------ src/new_edns/mod.rs | 26 ++++------ src/new_rdata/basic.rs | 83 ++++++++++++++---------------- src/new_rdata/mod.rs | 38 ++++++++------ 10 files changed, 179 insertions(+), 204 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 8df3c3d7c..3eab4d7cd 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -4,9 +4,8 @@ use core::fmt; use super::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, - Message, }; //----------- CharStr -------------------------------------------------------- @@ -20,27 +19,22 @@ pub struct CharStr { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for &'a CharStr { - fn split_from_message( - message: &'a Message, +impl<'a> SplitMessageBytes<'a> for &'a CharStr { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - let (this, rest) = Self::split_bytes(bytes)?; - Ok((this, bytes.len() - rest.len())) + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) } } -impl<'a> ParseFromMessage<'a> for &'a CharStr { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for &'a CharStr { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - message - .contents - .get(start..) - .ok_or(ParseError) - .and_then(Self::parse_bytes) + Self::parse_bytes(&contents[start..]) } } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index f33451f3b..55a83e82e 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -10,9 +10,8 @@ use core::{ use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, - Message, }; use super::LabelIter; @@ -239,9 +238,9 @@ impl RevNameBuf { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for RevNameBuf { - fn split_from_message( - message: &'a Message, +impl<'a> SplitMessageBytes<'a> for RevNameBuf { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { // NOTE: The input may be controlled by an attacker. Compression @@ -251,7 +250,6 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { // disallow a name to point to data _after_ it. Standard name // compressors will never generate such pointers. - let contents = &message.contents; let mut buffer = Self::empty(); // Perform the first iteration early, to catch the end of the name. @@ -282,16 +280,15 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { } } -impl<'a> ParseFromMessage<'a> for RevNameBuf { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for RevNameBuf { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { // See 'split_from_message()' for details. The only differences are // in the range of the first iteration, and the check that the first // iteration exactly covers the input range. - let contents = &message.contents; let mut buffer = Self::empty(); // Perform the first iteration early, to catch the end of the name. diff --git a/src/new_base/name/unparsed.rs b/src/new_base/name/unparsed.rs index 828c92229..4a76bee6c 100644 --- a/src/new_base/name/unparsed.rs +++ b/src/new_base/name/unparsed.rs @@ -3,9 +3,8 @@ use domain_macros::*; use crate::new_base::{ - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::ParseError, - Message, }; //----------- UnparsedName --------------------------------------------------- @@ -77,12 +76,12 @@ impl UnparsedName { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for &'a UnparsedName { - fn split_from_message( - message: &'a Message, +impl<'a> SplitMessageBytes<'a> for &'a UnparsedName { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.contents.get(start..).ok_or(ParseError)?; + let bytes = &contents[start..]; let mut offset = 0; let offset = loop { match bytes[offset..] { @@ -120,12 +119,12 @@ impl<'a> SplitFromMessage<'a> for &'a UnparsedName { } } -impl<'a> ParseFromMessage<'a> for &'a UnparsedName { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for &'a UnparsedName { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let bytes = message.contents.get(start..).ok_or(ParseError)?; + let bytes = &contents[start..]; let mut offset = 0; loop { match bytes[offset..] { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index cbe949681..03ac16995 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,69 +1,68 @@ //! Parsing DNS messages from the wire format. //! -//! This module provides [`ParseFromMessage`] and [`SplitFromMessage`], which -//! are specializations of [`ParseBytes`] and [`SplitBytes`] to DNS messages. -//! When parsing data within a DNS message, these traits allow access to all -//! preceding bytes in the message so that compressed names can be resolved. +//! This module provides [`ParseMessageBytes`] and [`SplitMessageBytes`], +//! which are specializations of [`ParseBytes`] and [`SplitBytes`] to DNS +//! messages. When parsing data within a DNS message, these traits allow +//! access to all preceding bytes in the message so that compressed names can +//! be resolved. //! //! [`ParseBytes`]: super::wire::ParseBytes //! [`SplitBytes`]: super::wire::SplitBytes pub use super::wire::ParseError; -use super::{ - wire::{ParseBytesByRef, SplitBytesByRef}, - Message, -}; +use super::wire::{ParseBytesByRef, SplitBytesByRef}; -//----------- Message-aware parsing traits ----------------------------------- +//----------- Message parsing traits ----------------------------------------- /// A type that can be parsed from a DNS message. -pub trait SplitFromMessage<'a>: Sized + ParseFromMessage<'a> { +pub trait SplitMessageBytes<'a>: Sized + ParseMessageBytes<'a> { /// Parse a value from the start of a byte string within a DNS message. /// - /// The byte string to parse is `message.contents[start..]`. The previous - /// data in the message can be used for resolving compressed names. + /// The contents of the DNS message is provided as `contents`. + /// `contents[start..]` is the beginning of the input to be parsed. The + /// earlier bytes are provided for resolving compressed domain names. /// /// If parsing is successful, the parsed value and the offset for the rest /// of the input are returned. If `len` bytes were parsed to form `self`, /// `start + len` should be the returned offset. - fn split_from_message( - message: &'a Message, + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError>; } -/// A type that can be parsed from a string in a DNS message. -pub trait ParseFromMessage<'a>: Sized { - /// Parse a value from a byte string within a DNS message. +/// A type that can be parsed from bytes in a DNS message. +pub trait ParseMessageBytes<'a>: Sized { + /// Parse a value from bytes in a DNS message. /// - /// The byte string to parse is `message.contents[start..]`. The previous - /// data in the message can be used for resolving compressed names. + /// The contents of the DNS message (up to and including the actual bytes + /// to be parsed) is provided as `contents`. `contents[start..]` is the + /// input to be parsed. The earlier bytes are provided for resolving + /// compressed domain names. /// /// If parsing is successful, the parsed value is returned. - fn parse_from_message( - message: &'a Message, + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result; } -impl<'a, T: ?Sized + SplitBytesByRef> SplitFromMessage<'a> for &'a T { - fn split_from_message( - message: &'a Message, +impl<'a, T: ?Sized + SplitBytesByRef> SplitMessageBytes<'a> for &'a T { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - let (this, rest) = T::split_bytes_by_ref(bytes)?; - Ok((this, bytes.len() - rest.len())) + T::split_bytes_by_ref(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) } } -impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { - fn parse_from_message( - message: &'a Message, +impl<'a, T: ?Sized + ParseBytesByRef> ParseMessageBytes<'a> for &'a T { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - T::parse_bytes_by_ref(bytes) + T::parse_bytes_by_ref(&contents[start..]) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index e4602a4b6..04dfa2d9e 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -5,9 +5,8 @@ use domain_macros::*; use super::{ build::{self, BuildIntoMessage, BuildResult}, name::RevNameBuf, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{AsBytes, ParseError, U16}, - Message, }; //----------- Question ------------------------------------------------------- @@ -43,32 +42,32 @@ impl Question { //--- Parsing from DNS messages -impl<'a, N> SplitFromMessage<'a> for Question +impl<'a, N> SplitMessageBytes<'a> for Question where - N: SplitFromMessage<'a>, + N: SplitMessageBytes<'a>, { - fn split_from_message( - message: &'a Message, + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let (qname, rest) = N::split_from_message(message, start)?; - let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; - let (&qclass, rest) = <&QClass>::split_from_message(message, rest)?; + let (qname, rest) = N::split_message_bytes(contents, start)?; + let (&qtype, rest) = <&QType>::split_message_bytes(contents, rest)?; + let (&qclass, rest) = <&QClass>::split_message_bytes(contents, rest)?; Ok((Self::new(qname, qtype, qclass), rest)) } } -impl<'a, N> ParseFromMessage<'a> for Question +impl<'a, N> ParseMessageBytes<'a> for Question where - N: SplitFromMessage<'a>, + N: SplitMessageBytes<'a>, { - fn parse_from_message( - message: &'a Message, + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let (qname, rest) = N::split_from_message(message, start)?; - let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; - let &qclass = <&QClass>::parse_from_message(message, rest)?; + let (qname, rest) = N::split_message_bytes(contents, start)?; + let (&qtype, rest) = <&QType>::split_message_bytes(contents, rest)?; + let &qclass = <&QClass>::parse_message_bytes(contents, rest)?; Ok(Self::new(qname, qtype, qclass)) } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 742a66977..0c0e6fa3c 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -5,12 +5,11 @@ use core::{borrow::Borrow, ops::Deref}; use super::{ build::{self, BuildIntoMessage, BuildResult}, name::RevNameBuf, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SizePrefixed, SplitBytes, SplitBytesByRef, TruncationError, U16, U32, }, - Message, }; //----------- Record --------------------------------------------------------- @@ -60,44 +59,44 @@ impl Record { //--- Parsing from DNS messages -impl<'a, N, D> SplitFromMessage<'a> for Record +impl<'a, N, D> SplitMessageBytes<'a> for Record where - N: SplitFromMessage<'a>, + N: SplitMessageBytes<'a>, D: ParseRecordData<'a>, { - fn split_from_message( - message: &'a Message, + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let (rname, rest) = N::split_from_message(message, start)?; - let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; - let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; - let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; + let (rname, rest) = N::split_message_bytes(contents, start)?; + let (&rtype, rest) = <&RType>::split_message_bytes(contents, rest)?; + let (&rclass, rest) = <&RClass>::split_message_bytes(contents, rest)?; + let (&ttl, rest) = <&TTL>::split_message_bytes(contents, rest)?; let rdata_start = rest; let (_, rest) = - <&SizePrefixed<[u8]>>::split_from_message(message, rest)?; - let message = message.slice_to(rest); - let rdata = D::parse_record_data(message, rdata_start, rtype)?; + <&SizePrefixed<[u8]>>::split_message_bytes(contents, rest)?; + let rdata = + D::parse_record_data(&contents[..rest], rdata_start, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } } -impl<'a, N, D> ParseFromMessage<'a> for Record +impl<'a, N, D> ParseMessageBytes<'a> for Record where - N: SplitFromMessage<'a>, + N: SplitMessageBytes<'a>, D: ParseRecordData<'a>, { - fn parse_from_message( - message: &'a Message, + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let (rname, rest) = N::split_from_message(message, start)?; - let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; - let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; - let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; - let _ = <&SizePrefixed<[u8]>>::parse_from_message(message, rest)?; - let rdata = D::parse_record_data(message, rest, rtype)?; + let (rname, rest) = N::split_message_bytes(contents, start)?; + let (&rtype, rest) = <&RType>::split_message_bytes(contents, rest)?; + let (&rclass, rest) = <&RClass>::split_message_bytes(contents, rest)?; + let (&ttl, rest) = <&TTL>::split_message_bytes(contents, rest)?; + let _ = <&SizePrefixed<[u8]>>::parse_message_bytes(contents, rest)?; + let rdata = D::parse_record_data(contents, rest, rtype)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -305,12 +304,11 @@ pub struct TTL { pub trait ParseRecordData<'a>: Sized { /// Parse DNS record data of the given type from a DNS message. fn parse_record_data( - message: &'a Message, + contents: &'a [u8], start: usize, rtype: RType, ) -> Result { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - Self::parse_record_data_bytes(bytes, rtype) + Self::parse_record_data_bytes(&contents[start..], rtype) } /// Parse DNS record data of the given type from a byte string. diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index 751a57395..5ac9effa9 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -7,8 +7,7 @@ use core::{ use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, - Message, + parse::{ParseMessageBytes, SplitMessageBytes}, }; use super::{ @@ -110,31 +109,28 @@ impl AsMut for SizePrefixed { //--- Parsing from DNS messages -impl<'b, T: ParseFromMessage<'b>> ParseFromMessage<'b> for SizePrefixed { - fn parse_from_message( - message: &'b Message, +impl<'b, T: ParseMessageBytes<'b>> ParseMessageBytes<'b> for SizePrefixed { + fn parse_message_bytes( + contents: &'b [u8], start: usize, ) -> Result { - let (&size, rest) = <&U16>::split_from_message(message, start)?; - if rest + size.get() as usize != message.contents.len() { + let (&size, rest) = <&U16>::split_message_bytes(contents, start)?; + if rest + size.get() as usize != contents.len() { return Err(ParseError); } - T::parse_from_message(message, rest).map(Self::new) + T::parse_message_bytes(contents, rest).map(Self::new) } } -impl<'b, T: ParseFromMessage<'b>> SplitFromMessage<'b> for SizePrefixed { - fn split_from_message( - message: &'b Message, +impl<'b, T: ParseMessageBytes<'b>> SplitMessageBytes<'b> for SizePrefixed { + fn split_message_bytes( + contents: &'b [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let (&size, rest) = <&U16>::split_from_message(message, start)?; + let (&size, rest) = <&U16>::split_message_bytes(contents, start)?; let (start, rest) = (rest, rest + size.get() as usize); - if rest > message.contents.len() { - return Err(ParseError); - } - let message = message.slice_to(rest); - let data = T::parse_from_message(message, start)?; + let contents = contents.get(..rest).ok_or(ParseError)?; + let data = T::parse_message_bytes(contents, start)?; Ok((Self::new(data), rest)) } } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 8bc3c55b6..96233221d 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -8,12 +8,11 @@ use domain_macros::*; use crate::{ new_base::{ - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SizePrefixed, SplitBytes, TruncationError, U16, }, - Message, }, new_rdata::Opt, }; @@ -49,27 +48,22 @@ pub struct EdnsRecord<'a> { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { - fn split_from_message( - message: &'a Message, +impl<'a> SplitMessageBytes<'a> for EdnsRecord<'a> { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - let (this, rest) = Self::split_bytes(bytes)?; - Ok((this, message.contents.len() - rest.len())) + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) } } -impl<'a> ParseFromMessage<'a> for EdnsRecord<'a> { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for EdnsRecord<'a> { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - message - .contents - .get(start..) - .ok_or(ParseError) - .and_then(Self::parse_bytes) + Self::parse_bytes(&contents[start..]) } } diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 456da881c..43d089100 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -14,9 +14,9 @@ use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{AsBytes, ParseBytes, ParseError, SplitBytes, U16, U32}, - CharStr, Message, Serial, + CharStr, Serial, }; //----------- A -------------------------------------------------------------- @@ -114,12 +114,12 @@ pub struct Ns { //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { - fn parse_from_message( - message: &'a Message, +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ns { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - N::parse_from_message(message, start).map(|name| Self { name }) + N::parse_message_bytes(contents, start).map(|name| Self { name }) } } @@ -155,12 +155,12 @@ pub struct CName { //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for CName { - fn parse_from_message( - message: &'a Message, +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for CName { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - N::parse_from_message(message, start).map(|name| Self { name }) + N::parse_message_bytes(contents, start).map(|name| Self { name }) } } @@ -211,18 +211,18 @@ pub struct Soa { //--- Parsing from DNS messages -impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { - fn parse_from_message( - message: &'a Message, +impl<'a, N: SplitMessageBytes<'a>> ParseMessageBytes<'a> for Soa { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let (mname, rest) = N::split_from_message(message, start)?; - let (rname, rest) = N::split_from_message(message, rest)?; - let (&serial, rest) = <&Serial>::split_from_message(message, rest)?; - let (&refresh, rest) = <&U32>::split_from_message(message, rest)?; - let (&retry, rest) = <&U32>::split_from_message(message, rest)?; - let (&expire, rest) = <&U32>::split_from_message(message, rest)?; - let &minimum = <&U32>::parse_from_message(message, rest)?; + let (mname, rest) = N::split_message_bytes(contents, start)?; + let (rname, rest) = N::split_message_bytes(contents, rest)?; + let (&serial, rest) = <&Serial>::split_message_bytes(contents, rest)?; + let (&refresh, rest) = <&U32>::split_message_bytes(contents, rest)?; + let (&retry, rest) = <&U32>::split_message_bytes(contents, rest)?; + let (&expire, rest) = <&U32>::split_message_bytes(contents, rest)?; + let &minimum = <&U32>::parse_message_bytes(contents, rest)?; Ok(Self { mname, @@ -330,12 +330,12 @@ pub struct Ptr { //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { - fn parse_from_message( - message: &'a Message, +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ptr { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - N::parse_from_message(message, start).map(|name| Self { name }) + N::parse_message_bytes(contents, start).map(|name| Self { name }) } } @@ -361,16 +361,12 @@ pub struct HInfo<'a> { //--- Parsing from DNS messages -impl<'a> ParseFromMessage<'a> for HInfo<'a> { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for HInfo<'a> { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - message - .contents - .get(start..) - .ok_or(ParseError) - .and_then(Self::parse_bytes) + Self::parse_bytes(&contents[start..]) } } @@ -414,13 +410,14 @@ pub struct Mx { //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { - fn parse_from_message( - message: &'a Message, +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Mx { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let (&preference, rest) = <&U16>::split_from_message(message, start)?; - let exchange = N::parse_from_message(message, rest)?; + let (&preference, rest) = + <&U16>::split_message_bytes(contents, start)?; + let exchange = N::parse_message_bytes(contents, rest)?; Ok(Self { preference, exchange, @@ -471,16 +468,12 @@ impl Txt { //--- Parsing from DNS messages -impl<'a> ParseFromMessage<'a> for &'a Txt { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for &'a Txt { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - message - .contents - .get(start..) - .ok_or(ParseError) - .and_then(Self::parse_bytes) + Self::parse_bytes(&contents[start..]) } } diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index e4b94a538..70f041240 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -4,9 +4,9 @@ use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, - Message, ParseRecordData, RType, + ParseRecordData, RType, }; //----------- Concrete record data types ------------------------------------- @@ -67,42 +67,48 @@ pub enum RecordData<'a, N> { impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> where - N: SplitBytes<'a> + SplitFromMessage<'a>, + N: SplitBytes<'a> + SplitMessageBytes<'a>, { fn parse_record_data( - message: &'a Message, + contents: &'a [u8], start: usize, rtype: RType, ) -> Result { match rtype { - RType::A => <&A>::parse_from_message(message, start).map(Self::A), - RType::NS => Ns::parse_from_message(message, start).map(Self::Ns), + RType::A => { + <&A>::parse_message_bytes(contents, start).map(Self::A) + } + RType::NS => { + Ns::parse_message_bytes(contents, start).map(Self::Ns) + } RType::CNAME => { - CName::parse_from_message(message, start).map(Self::CName) + CName::parse_message_bytes(contents, start).map(Self::CName) } RType::SOA => { - Soa::parse_from_message(message, start).map(Self::Soa) + Soa::parse_message_bytes(contents, start).map(Self::Soa) } RType::WKS => { - <&Wks>::parse_from_message(message, start).map(Self::Wks) + <&Wks>::parse_message_bytes(contents, start).map(Self::Wks) } RType::PTR => { - Ptr::parse_from_message(message, start).map(Self::Ptr) + Ptr::parse_message_bytes(contents, start).map(Self::Ptr) } RType::HINFO => { - HInfo::parse_from_message(message, start).map(Self::HInfo) + HInfo::parse_message_bytes(contents, start).map(Self::HInfo) + } + RType::MX => { + Mx::parse_message_bytes(contents, start).map(Self::Mx) } - RType::MX => Mx::parse_from_message(message, start).map(Self::Mx), RType::TXT => { - <&Txt>::parse_from_message(message, start).map(Self::Txt) + <&Txt>::parse_message_bytes(contents, start).map(Self::Txt) } RType::AAAA => { - <&Aaaa>::parse_from_message(message, start).map(Self::Aaaa) + <&Aaaa>::parse_message_bytes(contents, start).map(Self::Aaaa) } RType::OPT => { - <&Opt>::parse_from_message(message, start).map(Self::Opt) + <&Opt>::parse_message_bytes(contents, start).map(Self::Opt) } - _ => <&UnknownRecordData>::parse_from_message(message, start) + _ => <&UnknownRecordData>::parse_message_bytes(contents, start) .map(|data| Self::Unknown(rtype, data)), } } From e2729478571af3f96ac227fa31516758bbe6b4ed Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 23 Jan 2025 16:02:36 +0100 Subject: [PATCH 097/111] [new_base/build] Overhaul message building - 'BuilderContext' now tracks the last question / record in a message, allowing its building to be recovered in the future (WIP). - 'MessageBuilder' inlines the 'Builder' and just stores a mutable reference to a 'Message' for simplicity. - 'Builder' no longer gives access to the message header, and uses '&UnsafeCell<[u8]>' to represent the message contents. --- src/new_base/build/builder.rs | 309 ++++++++-------------------- src/new_base/build/context.rs | 135 +++++++++++++ src/new_base/build/message.rs | 290 ++++++++++++--------------- src/new_base/build/mod.rs | 51 +---- src/new_base/build/question.rs | 106 ++++++++++ src/new_base/build/record.rs | 310 +++++++++++++++++------------ src/new_base/wire/size_prefixed.rs | 6 +- 7 files changed, 647 insertions(+), 560 deletions(-) create mode 100644 src/new_base/build/context.rs create mode 100644 src/new_base/build/question.rs diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 16444349c..edcc9b543 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -1,18 +1,18 @@ //! A builder for DNS messages. use core::{ - marker::PhantomData, + cell::UnsafeCell, mem::ManuallyDrop, - ptr::{self, NonNull}, + ptr::{self}, + slice, }; use crate::new_base::{ name::RevName, - wire::{BuildBytes, ParseBytesByRef, TruncationError}, - Header, Message, + wire::{BuildBytes, TruncationError}, }; -use super::BuildCommitted; +use super::{BuildCommitted, BuilderContext}; //----------- Builder -------------------------------------------------------- @@ -60,144 +60,49 @@ use super::BuildCommitted; /// will be reverted on failure). For this, we have [`delegate()`]. /// /// [`delegate()`]: Self::delegate() -/// -/// For example: -/// -/// ``` -/// # use domain::new_base::build::{Builder, BuildResult, BuilderContext}; -/// -/// /// A build function with the conventional type signature. -/// fn foo(mut builder: Builder<'_>) -> BuildResult { -/// // Content added by the parent builder is considered committed. -/// assert_eq!(builder.committed(), b"hi! "); -/// -/// // Append some content to the builder. -/// builder.append_bytes(b"foo!")?; -/// -/// // Try appending a very long string, which can't fit. -/// builder.append_bytes(b"helloworldthisiswaytoobig")?; -/// -/// Ok(builder.commit()) -/// } -/// -/// // Construct a builder for a particular buffer. -/// let mut buffer = [0u8; 20]; -/// let mut context = BuilderContext::default(); -/// let mut builder = Builder::new(&mut buffer, &mut context); -/// -/// // Try appending some content to the builder. -/// builder.append_bytes(b"hi! ").unwrap(); -/// assert_eq!(builder.appended(), b"hi! "); -/// -/// // Try calling 'foo' -- note that it will fail. -/// // Note that we delegated the builder. -/// foo(builder.delegate()).unwrap_err(); -/// -/// // No partial content was written. -/// assert_eq!(builder.appended(), b"hi! "); -/// ``` pub struct Builder<'b> { - /// The message being built. + /// The contents of the built message. /// - /// The message is divided into four parts: + /// The buffer is divided into three parts: /// - /// - The message header (borrowed mutably by this type). /// - Committed message contents (borrowed *immutably* by this type). /// - Appended message contents (borrowed mutably by this type). /// - Uninitialized message contents (borrowed mutably by this type). - message: NonNull, - - _message: PhantomData<&'b mut Message>, + contents: &'b UnsafeCell<[u8]>, /// Context for building. context: &'b mut BuilderContext, - /// The commit point of this builder. + /// The start point of this builder. /// /// Message contents up to this point are committed and cannot be removed /// by this builder. Message contents following this (up to the size in /// the builder context) are appended but uncommitted. - commit: usize, + start: usize, } -/// # Initialization -/// -/// In order to begin building a DNS message: -/// -/// ``` -/// # use domain::new_base::build::{Builder, BuilderContext}; -/// -/// // Allocate a slice of 'u8's somewhere. -/// let mut buffer = [0u8; 20]; -/// -/// // Obtain a builder context. -/// // -/// // The value doesn't matter, it will be overwritten. -/// let mut context = BuilderContext::default(); -/// -/// // Construct the actual 'Builder'. -/// let builder = Builder::new(&mut buffer, &mut context); -/// -/// assert!(builder.committed().is_empty()); -/// assert!(builder.appended().is_empty()); -/// ``` impl<'b> Builder<'b> { - /// Create a [`Builder`] for a new, empty DNS message. - /// - /// The message header is left uninitialized. Use [`Self::header_mut()`] - /// to initialize it. The message contents are completely empty. - /// - /// The provided builder context will be overwritten with a default state. - /// - /// # Panics - /// - /// Panics if the buffer is less than 12 bytes long (which is the minimum - /// possible size for a DNS message). - pub fn new( - buffer: &'b mut [u8], - context: &'b mut BuilderContext, - ) -> Self { - let message = Message::parse_bytes_by_mut(buffer) - .expect("The buffure must be at least 12 bytes in size"); - context.size = 0; - - // SAFETY: 'message' and 'context' are now consistent. - unsafe { Self::from_raw_parts(message.into(), context, 0) } - } - /// Construct a [`Builder`] from raw parts. /// - /// The provided components must originate from [`into_raw_parts()`], and - /// none of the components can be modified since they were extracted. - /// - /// [`into_raw_parts()`]: Self::into_raw_parts() - /// - /// This method is useful when overcoming limitations in lifetimes or - /// borrow checking, or when a builder has to be constructed from another - /// with specific characteristics. - /// /// # Safety /// - /// The expression `from_raw_parts(message, context, commit)` is sound if + /// The expression `from_raw_parts(contents, context, start)` is sound if /// and only if all of the following conditions are satisfied: /// - /// - `message` is a valid reference for the lifetime `'b`. - /// - `message.header` is mutably borrowed for `'b`. - /// - `message.contents[..commit]` is immutably borrowed for `'b`. - /// - `message.contents[commit..]` is mutably borrowed for `'b`. + /// - `message[..start]` is immutably borrowed for `'b`. + /// - `message[start..]` is mutably borrowed for `'b`. /// /// - `message` and `context` originate from the same builder. - /// - `commit <= context.size() <= message.contents.len()`. + /// - `start <= context.size() <= message.len()`. pub unsafe fn from_raw_parts( - message: NonNull, + contents: &'b UnsafeCell<[u8]>, context: &'b mut BuilderContext, - commit: usize, + start: usize, ) -> Self { Self { - message, - _message: PhantomData, + contents, context, - commit, + start, } } } @@ -210,38 +115,30 @@ impl<'b> Builder<'b> { /// ```text /// name | position /// --------------+--------- -/// header | -/// committed | 0 .. commit -/// appended | commit .. size -/// uninitialized | size .. limit +/// committed | 0 .. start +/// appended | start .. offset +/// uninitialized | offset .. limit /// inaccessible | limit .. /// ``` /// -/// The DNS message header can be modified at any time. It is made available -/// through [`header()`] and [`header_mut()`]. In general, it is inadvisable -/// to change the section counts arbitrarily (although it will not cause -/// undefined behaviour). -/// -/// [`header()`]: Self::header() -/// [`header_mut()`]: Self::header_mut() -/// /// The committed content of the builder is immutable, and is available to /// reference, through [`committed()`], for the lifetime `'b`. /// /// [`committed()`]: Self::committed() /// -/// The appended content of the builder is made available via [`appended()`]. -/// It is content that has been added by this builder, but that has not yet -/// been committed. When the [`Builder`] is dropped, this content is removed -/// (it becomes uninitialized). Appended content can be modified, but any -/// compressed names within it have to be handled with great care; they can -/// only be modified by removing them entirely (by rewinding the builder, -/// using [`rewind()`]) and building them again. When compressed names are -/// guaranteed to not be modified, [`appended_mut()`] can be used. +/// The appended but uncommitted content of the builder is made available via +/// [`uncommitted_mut()`]. It is content that has been added by this builder, +/// but that has not yet been committed. When the [`Builder`] is dropped, +/// this content is removed (it becomes uninitialized). Appended content can +/// be modified, but any compressed names within it have to be handled with +/// great care; they can only be modified by removing them entirely (by +/// rewinding the builder, using [`rewind()`]) and building them again. When +/// compressed names are guaranteed to not be modified, [`uncommitted_mut()`] +/// can be used. /// /// [`appended()`]: Self::appended() /// [`rewind()`]: Self::rewind() -/// [`appended_mut()`]: Self::appended_mut() +/// [`uncommitted_mut()`]: Self::uncommitted_mut() /// /// The uninitialized space in the builder will be written to when appending /// new content. It can be accessed directly, in case that is more efficient @@ -259,37 +156,31 @@ impl<'b> Builder<'b> { /// /// [`limit_to()`]: Self::limit_to() impl<'b> Builder<'b> { - /// The header of the DNS message. - pub fn header(&self) -> &Header { - // SAFETY: 'message.header' is mutably borrowed by 'self'. - unsafe { &(*self.message.as_ptr()).header } - } - - /// The header of the DNS message, mutably. - /// - /// It is possible to modify the section counts arbitrarily through this - /// method; while doing so cannot cause undefined behaviour, it is not - /// recommended. - pub fn header_mut(&mut self) -> &mut Header { - // SAFETY: 'message.header' is mutably borrowed by 'self'. - unsafe { &mut (*self.message.as_ptr()).header } - } - /// Committed message contents. pub fn committed(&self) -> &'b [u8] { - // SAFETY: 'message.contents[..commit]' is immutably borrowed by - // 'self'. - unsafe { &(*self.message.as_ptr()).contents[..self.commit] } + let message = self.contents.get().cast_const().cast(); + // SAFETY: 'message[..start]' is immutably borrowed. + unsafe { slice::from_raw_parts(message, self.start) } + } + + /// Appended (and committed) message contents. + pub fn appended(&self) -> &[u8] { + let message = self.contents.get().cast_const().cast(); + // SAFETY: 'message[..offset]' is (im)mutably borrowed. + unsafe { slice::from_raw_parts(message, self.context.size) } } /// The appended but uncommitted contents of the message. /// /// The builder can modify or rewind these contents, so they are offered /// with a short lifetime. - pub fn appended(&self) -> &[u8] { - // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. - let range = self.commit..self.context.size; - unsafe { &(*self.message.as_ptr()).contents[range] } + pub fn uncommitted(&self) -> &[u8] { + let message = self.contents.get().cast::().cast_const(); + // SAFETY: It is guaranteed that 'start <= message.len()'. + let message = unsafe { message.offset(self.start as isize) }; + let size = self.context.size - self.start; + // SAFETY: 'message[start..]' is mutably borrowed. + unsafe { slice::from_raw_parts(message, size) } } /// The appended but uncommitted contents of the message, mutably. @@ -298,10 +189,13 @@ impl<'b> Builder<'b> { /// /// The caller must not modify any compressed names among these bytes. /// This can invalidate name compression state. - pub unsafe fn appended_mut(&mut self) -> &mut [u8] { - // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. - let range = self.commit..self.context.size; - unsafe { &mut (*self.message.as_ptr()).contents[range] } + pub unsafe fn uncommitted_mut(&mut self) -> &mut [u8] { + let message = self.contents.get().cast::(); + // SAFETY: It is guaranteed that 'start <= message.len()'. + let message = unsafe { message.offset(self.start as isize) }; + let size = self.context.size - self.start; + // SAFETY: 'message[start..]' is mutably borrowed. + unsafe { slice::from_raw_parts_mut(message, size) } } /// Uninitialized space in the message buffer. @@ -310,39 +204,12 @@ impl<'b> Builder<'b> { /// should be treated as appended content in the message, call /// [`self.mark_appended(n)`](Self::mark_appended()). pub fn uninitialized(&mut self) -> &mut [u8] { - // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. - unsafe { &mut (*self.message.as_ptr()).contents[self.context.size..] } - } - - /// The message with all committed contents. - /// - /// The header of the message can be modified by the builder, so the - /// returned reference has a short lifetime. The message contents can be - /// borrowed for a longer lifetime -- see [`committed()`]. The message - /// does not include content that has been appended but not committed. - /// - /// [`committed()`]: Self::committed() - pub fn message(&self) -> &Message { - // SAFETY: All of 'message' can be immutably borrowed by 'self'. - unsafe { self.message.as_ref() }.slice_to(self.commit) - } - - /// The message including any uncommitted contents. - pub fn cur_message(&self) -> &Message { - // SAFETY: All of 'message' can be immutably borrowed by 'self'. - unsafe { self.message.as_ref() }.slice_to(self.context.size) - } - - /// A pointer to the message, including any uncommitted contents. - /// - /// The first `commit` bytes of the message contents (also provided by - /// [`Self::committed()`]) are immutably borrowed for the lifetime `'b`. - /// The remainder of the message is initialized and borrowed by `self`. - pub fn cur_message_ptr(&mut self) -> NonNull { - let message = self.message.as_ptr(); - let size = self.context.size; - let message = unsafe { Message::ptr_slice_to(message, size) }; - unsafe { NonNull::new_unchecked(message) } + let message = self.contents.get().cast::(); + // SAFETY: It is guaranteed that 'size <= message.len()'. + let message = unsafe { message.offset(self.context.size as isize) }; + let size = self.max_size() - self.context.size; + // SAFETY: 'message[size..]' is mutably borrowed. + unsafe { slice::from_raw_parts_mut(message, size) } } /// The builder context. @@ -356,7 +223,15 @@ impl<'b> Builder<'b> { /// initialized. The content before this point has been committed and is /// immutable. The builder can be rewound up to this point. pub fn start(&self) -> usize { - self.commit + self.start + } + + /// The append point of this builder. + /// + /// This is the offset into the message contents at which new data will be + /// written. The content after this point is uninitialized. + pub fn offset(&self) -> usize { + self.context.size } /// The size limit of this builder. @@ -365,12 +240,11 @@ impl<'b> Builder<'b> { /// [`TruncationError`]s will occur. The limit can be tightened using /// [`limit_to()`](Self::limit_to()). pub fn max_size(&self) -> usize { - // SAFETY: 'Message' ends with a slice DST, and so references to it - // hold the length of that slice; we can cast it to another slice type - // and the pointer representation is unchanged. By using a slice type - // of ZST elements, aliasing is impossible, and it can be dereferenced + // SAFETY: We can cast 'contents' to another slice type and the + // pointer representation is unchanged. By using a slice type of ZST + // elements, aliasing is impossible, and it can be dereferenced // safely. - unsafe { &*(self.message.as_ptr() as *mut [()]) }.len() + unsafe { &*(self.contents.get() as *mut [()]) }.len() } /// Decompose this builder into raw parts. @@ -389,14 +263,14 @@ impl<'b> Builder<'b> { /// The builder can be recomposed with [`Self::from_raw_parts()`]. pub fn into_raw_parts( self, - ) -> (NonNull, &'b mut BuilderContext, usize) { + ) -> (&'b UnsafeCell<[u8]>, &'b mut BuilderContext, usize) { // NOTE: The context has to be moved out carefully. - let (message, commit) = (self.message, self.commit); + let (contents, start) = (self.contents, self.start); let this = ManuallyDrop::new(self); let this = (&*this) as *const Self; // SAFETY: 'this' is a valid object that can be moved out of. let context = unsafe { ptr::read(ptr::addr_of!((*this).context)) }; - (message, context, commit) + (contents, context, start) } } @@ -437,7 +311,7 @@ impl<'b> Builder<'b> { impl Builder<'_> { /// Rewind the builder, removing all uncommitted content. pub fn rewind(&mut self) { - self.context.size = self.commit; + self.context.size = self.start; } /// Commit the changes made by this builder. @@ -447,7 +321,7 @@ impl Builder<'_> { /// this method on success paths. pub fn commit(mut self) -> BuildCommitted { // Update 'commit' so that the drop glue is a no-op. - self.commit = self.context.size; + self.start = self.context.size; BuildCommitted } @@ -459,9 +333,11 @@ impl Builder<'_> { /// limit, a [`TruncationError`] is returned. pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { if self.context.size <= size { - let message = self.message.as_ptr(); - let message = unsafe { Message::ptr_slice_to(message, size) }; - self.message = unsafe { NonNull::new_unchecked(message) }; + let message = self.contents.get().cast::(); + debug_assert!(size <= self.max_size()); + self.contents = unsafe { + &*(ptr::slice_from_raw_parts_mut(message, size) as *const _) + }; Ok(()) } else { Err(TruncationError) @@ -490,7 +366,7 @@ impl Builder<'_> { pub fn delegate(&mut self) -> Builder<'_> { let commit = self.context.size; unsafe { - Builder::from_raw_parts(self.message, &mut *self.context, commit) + Builder::from_raw_parts(self.contents, &mut *self.context, commit) } } @@ -550,18 +426,3 @@ unsafe impl Send for Builder<'_> {} // SAFETY: Only parts of the referenced message that are borrowed immutably // can be accessed through an immutable reference to `self`. unsafe impl Sync for Builder<'_> {} - -//----------- BuilderContext ------------------------------------------------- - -/// Context for building a DNS message. -/// -/// This type holds auxiliary information necessary for building DNS messages, -/// e.g. name compression state. To construct it, call [`default()`]. -/// -/// [`default()`]: Self::default() -#[derive(Clone, Debug, Default)] -pub struct BuilderContext { - // TODO: Name compression. - /// The current size of the message contents. - size: usize, -} diff --git a/src/new_base/build/context.rs b/src/new_base/build/context.rs new file mode 100644 index 000000000..2f7f43da1 --- /dev/null +++ b/src/new_base/build/context.rs @@ -0,0 +1,135 @@ +//! Context for building DNS messages. + +//----------- BuilderContext ------------------------------------------------- + +/// Context for building a DNS message. +/// +/// This type holds auxiliary information necessary for building DNS messages, +/// e.g. name compression state. To construct it, call [`default()`]. +/// +/// [`default()`]: Self::default() +#[derive(Clone, Debug, Default)] +pub struct BuilderContext { + // TODO: Name compression. + /// The current size of the message contents. + pub size: usize, + + /// The state of the DNS message. + pub state: MessageState, +} + +//----------- MessageState --------------------------------------------------- + +/// The state of a DNS message being built. +/// +/// A DNS message consists of a header, questions, answers, authorities, and +/// additionals. [`MessageState`] remembers the start position of the last +/// question or record in the message, allowing it to be modifying or removed +/// (for additional flexibility in the building process). +#[derive(Clone, Debug, Default)] +pub enum MessageState { + /// Questions are being built. + /// + /// The message already contains zero or more DNS questions. If there is + /// a last DNS question, its start position is unknown, so it cannot be + /// modified or removed. + /// + /// This is the default state for an empty message. + #[default] + Questions, + + /// A question is being built. + /// + /// The message contains one or more DNS questions. The last question can + /// be modified or truncated. + MidQuestion { + /// The offset of the question name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + }, + + /// Answer records are being built. + /// + /// The message already contains zero or more DNS answer records. If + /// there is a last DNS record, its start position is unknown, so it + /// cannot be modified or removed. + Answers, + + /// An answer record is being built. + /// + /// The message contains one or more DNS answer records. The last record + /// can be modified or truncated. + MidAnswer { + /// The offset of the record name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + + /// The offset of the record data. + /// + /// The offset is measured from the start of the message contents. + data: u16, + }, + + /// Authority records are being built. + /// + /// The message already contains zero or more DNS authority records. If + /// there is a last DNS record, its start position is unknown, so it + /// cannot be modified or removed. + Authorities, + + /// An authority record is being built. + /// + /// The message contains one or more DNS authority records. The last + /// record can be modified or truncated. + MidAuthority { + /// The offset of the record name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + + /// The offset of the record data. + /// + /// The offset is measured from the start of the message contents. + data: u16, + }, + + /// Additional records are being built. + /// + /// The message already contains zero or more DNS additional records. If + /// there is a last DNS record, its start position is unknown, so it + /// cannot be modified or removed. + Additionals, + + /// An additional record is being built. + /// + /// The message contains one or more DNS additional records. The last + /// record can be modified or truncated. + MidAdditional { + /// The offset of the record name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + + /// The offset of the record data. + /// + /// The offset is measured from the start of the message contents. + data: u16, + }, +} + +impl MessageState { + /// The current section index. + /// + /// Questions, answers, authorities, and additionals are mapped to 0, 1, + /// 2, and 3, respectively. + pub const fn section_index(&self) -> u8 { + match self { + Self::Questions | Self::MidQuestion { .. } => 0, + Self::Answers | Self::MidAnswer { .. } => 1, + Self::Authorities | Self::MidAuthority { .. } => 2, + Self::Additionals | Self::MidAdditional { .. } => 3, + } + } +} diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index a79928726..35ec3d145 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -1,45 +1,35 @@ //! Building whole DNS messages. +use core::cell::UnsafeCell; + use crate::new_base::{ - wire::TruncationError, Header, Message, Question, RClass, RType, Record, - TTL, + wire::{ParseBytesByRef, TruncationError}, + Header, Message, Question, RClass, RType, Record, TTL, }; -use super::{BuildIntoMessage, Builder, BuilderContext, RecordBuilder}; +use super::{ + BuildIntoMessage, Builder, BuilderContext, MessageState, QuestionBuilder, + RecordBuilder, +}; //----------- MessageBuilder ------------------------------------------------- /// A builder for a whole DNS message. /// -/// This is subtly different from a regular [`Builder`] -- it does not allow -/// for commits and so can always modify the entire message. It has methods -/// for adding entire questions and records to the message. +/// This is a high-level building interface, offering methods to put together +/// entire questions and records. It directly writes into an allocated buffer +/// (on the stack or the heap). pub struct MessageBuilder<'b> { - /// The underlying [`Builder`]. - /// - /// Its commit point is always 0. - inner: Builder<'b>, + /// The message being constructed. + message: &'b mut Message, + + /// Context for building. + pub(super) context: &'b mut BuilderContext, } //--- Initialization impl<'b> MessageBuilder<'b> { - /// Construct a [`MessageBuilder`] from raw parts. - /// - /// # Safety - /// - /// - `message` and `context` are paired together. - pub unsafe fn from_raw_parts( - message: &'b mut Message, - context: &'b mut BuilderContext, - ) -> Self { - // SAFETY: since 'commit' is 0, no part of the message is immutably - // borrowed; it is thus sound to represent as a mutable borrow. - let inner = - unsafe { Builder::from_raw_parts(message.into(), context, 0) }; - Self { inner } - } - /// Initialize an empty [`MessageBuilder`]. /// /// The message header is left uninitialized. use [`Self::header_mut()`] @@ -53,8 +43,10 @@ impl<'b> MessageBuilder<'b> { buffer: &'b mut [u8], context: &'b mut BuilderContext, ) -> Self { - let inner = Builder::new(buffer, context); - Self { inner } + let message = Message::parse_bytes_by_mut(buffer) + .expect("The caller's buffer is at least 12 bytes big"); + *context = BuilderContext::default(); + Self { message, context } } } @@ -62,21 +54,18 @@ impl<'b> MessageBuilder<'b> { impl<'b> MessageBuilder<'b> { /// The message header. - /// - /// The header can be modified by the builder, and so is only available - /// for a short lifetime. Note that it implements [`Copy`]. pub fn header(&self) -> &Header { - self.inner.header() + &self.message.header } - /// Mutable access to the message header. + /// The message header, mutably. pub fn header_mut(&mut self) -> &mut Header { - self.inner.header_mut() + &mut self.message.header } /// The message built thus far. pub fn message(&self) -> &Message { - self.inner.cur_message() + self.message.slice_to(self.context.size) } /// The message built thus far, mutably. @@ -86,34 +75,26 @@ impl<'b> MessageBuilder<'b> { /// The caller must not modify any compressed names among these bytes. /// This can invalidate name compression state. pub unsafe fn message_mut(&mut self) -> &mut Message { - // SAFETY: Since no bytes are committed, and the rest of the message - // is borrowed mutably for 'self', we can use a mutable reference. - unsafe { self.inner.cur_message_ptr().as_mut() } + self.message.slice_to_mut(self.context.size) } /// The builder context. pub fn context(&self) -> &BuilderContext { - self.inner.context() - } - - /// Decompose this builder into raw parts. - /// - /// This returns the message buffer and the context for this builder. The - /// two are linked, and the builder can be recomposed with - /// [`Self::from_raw_parts()`]. - pub fn into_raw_parts(self) -> (&'b mut Message, &'b mut BuilderContext) { - let (mut message, context, _commit) = self.inner.into_raw_parts(); - // SAFETY: As per 'Builder::into_raw_parts()', the message is borrowed - // mutably for the lifetime 'b. Since the commit point is 0, there is - // no immutably-borrowed content in the message, so it can be turned - // into a regular reference. - (unsafe { message.as_mut() }, context) + self.context } } //--- Interaction impl MessageBuilder<'_> { + /// Reborrow the builder with a shorter lifetime. + pub fn reborrow(&mut self) -> MessageBuilder<'_> { + MessageBuilder { + message: self.message, + context: self.context, + } + } + /// Limit the total message size. /// /// The message will not be allowed to exceed the given size, in bytes. @@ -121,148 +102,124 @@ impl MessageBuilder<'_> { /// or TCP packet size is not considered. If the message already exceeds /// this size, a [`TruncationError`] is returned. /// - /// This size will apply to all builders for this message (including those - /// that delegated to `self`). It will not be automatically revoked if - /// message building fails. - /// /// # Panics /// /// Panics if the given size is less than 12 bytes. pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { - self.inner.limit_to(size) - } - - /// Append a question. - /// - /// # Panics - /// - /// Panics if the message contains any records (as questions must come - /// before all records). - pub fn append_question( + if 12 + self.context.size <= size { + // Move out of 'message' so that the full lifetime is available. + // See the 'replace_with' and 'take_mut' crates. + debug_assert!(size < 12 + self.message.contents.len()); + let message = unsafe { core::ptr::read(&self.message) }; + // NOTE: Precondition checked, will not panic. + let message = message.slice_to_mut(size - 12); + unsafe { core::ptr::write(&mut self.message, message) }; + Ok(()) + } else { + Err(TruncationError) + } + } + + /// Truncate the message. + /// + /// This will remove all message contents and mark it as truncated. + pub fn truncate(&mut self) { + self.message.header.flags = + self.message.header.flags.set_truncated(true); + *self.context = BuilderContext::default(); + } + + /// Obtain a [`Builder`]. + pub(super) fn builder(&mut self, start: usize) -> Builder<'_> { + debug_assert!(start <= self.context.size); + unsafe { + let contents = &mut self.message.contents; + let contents = contents as *mut [u8] as *const UnsafeCell<[u8]>; + Builder::from_raw_parts(&*contents, &mut self.context, start) + } + } + + /// Build a question. + /// + /// If a question is already being built, it will be finished first. If + /// an answer, authority, or additional record has been added, [`None`] is + /// returned instead. + pub fn build_question( &mut self, question: &Question, - ) -> Result<(), TruncationError> - where - N: BuildIntoMessage, - { - // Ensure there are no records present. - assert_eq!(self.header().counts.as_array()[1..], [0, 0, 0]); + ) -> Result>, TruncationError> { + if self.context.state.section_index() > 0 { + // We've progressed into a later section. + return Ok(None); + } - question.build_into_message(self.inner.delegate())?; - self.header_mut().counts.questions += 1; - Ok(()) + self.context.state = MessageState::Questions; + QuestionBuilder::build(self.reborrow(), question).map(Some) } - /// Build an arbitrary record. + /// Build an answer record. /// - /// The record will be added to the specified section (1, 2, or 3, i.e. - /// answers, authorities, and additional records respectively). There - /// must not be any existing records in sections after this one. - pub fn build_record( + /// If a question or answer is already being built, it will be finished + /// first. If an authority or additional record has been added, [`None`] + /// is returned instead. + pub fn build_answer( &mut self, rname: impl BuildIntoMessage, rtype: RType, rclass: RClass, ttl: TTL, - section: u8, - ) -> Result, TruncationError> { - RecordBuilder::new( - self.inner.delegate(), + ) -> Result>, TruncationError> { + if self.context.state.section_index() > 1 { + // We've progressed into a later section. + return Ok(None); + } + + let record = Record { rname, rtype, rclass, ttl, - section, - ) - } - - /// Append an answer record. - /// - /// # Panics - /// - /// Panics if the message contains any authority or additional records. - pub fn append_answer( - &mut self, - record: &Record, - ) -> Result<(), TruncationError> - where - N: BuildIntoMessage, - D: BuildIntoMessage, - { - // Ensure there are no authority or additional records present. - assert_eq!(self.header().counts.as_array()[2..], [0, 0]); + rdata: &[] as &[u8], + }; - record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.answers += 1; - Ok(()) - } - - /// Build an answer record. - /// - /// # Panics - /// - /// Panics if the message contains any authority or additional records. - pub fn build_answer( - &mut self, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, - ) -> Result, TruncationError> { - self.build_record(rname, rtype, rclass, ttl, 1) - } - - /// Append an authority record. - /// - /// # Panics - /// - /// Panics if the message contains any additional records. - pub fn append_authority( - &mut self, - record: &Record, - ) -> Result<(), TruncationError> - where - N: BuildIntoMessage, - D: BuildIntoMessage, - { - // Ensure there are no additional records present. - assert_eq!(self.header().counts.as_array()[3..], [0]); - - record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.authorities += 1; - Ok(()) + self.context.state = MessageState::Answers; + RecordBuilder::build(self.reborrow(), &record).map(Some) } /// Build an authority record. /// - /// # Panics - /// - /// Panics if the message contains any additional records. + /// If a question, answer, or authority is already being built, it will be + /// finished first. If an additional record has been added, [`None`] is + /// returned instead. pub fn build_authority( &mut self, rname: impl BuildIntoMessage, rtype: RType, rclass: RClass, ttl: TTL, - ) -> Result, TruncationError> { - self.build_record(rname, rtype, rclass, ttl, 2) - } + ) -> Result>, TruncationError> { + if self.context.state.section_index() > 2 { + // We've progressed into a later section. + return Ok(None); + } - /// Append an additional record. - pub fn append_additional( - &mut self, - record: &Record, - ) -> Result<(), TruncationError> - where - N: BuildIntoMessage, - D: BuildIntoMessage, - { - record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.additional += 1; - Ok(()) + let record = Record { + rname, + rtype, + rclass, + ttl, + rdata: &[] as &[u8], + }; + + self.context.state = MessageState::Authorities; + RecordBuilder::build(self.reborrow(), &record).map(Some) } /// Build an additional record. + /// + /// If a question or record is already being built, it will be finished + /// first. Note that it is always possible to add an additional record to + /// a message. pub fn build_additional( &mut self, rname: impl BuildIntoMessage, @@ -270,6 +227,15 @@ impl MessageBuilder<'_> { rclass: RClass, ttl: TTL, ) -> Result, TruncationError> { - self.build_record(rname, rtype, rclass, ttl, 3) + let record = Record { + rname, + rtype, + rclass, + ttl, + rdata: &[] as &[u8], + }; + + self.context.state = MessageState::Additionals; + RecordBuilder::build(self.reborrow(), &record) } } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index e723ef521..ba62062b6 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -4,8 +4,6 @@ //! but it is not specialized to DNS messages. This module provides that //! specialization within an ergonomic interface. //! -//! # The High-Level Interface -//! //! The core of the high-level interface is [`MessageBuilder`]. It provides //! the most intuitive methods for appending whole questions and records. //! @@ -42,58 +40,25 @@ //! qtype: QType::A, //! qclass: QClass::IN, //! }; -//! builder.append_question(&question).unwrap(); +//! let _ = builder.build_question(&question).unwrap().unwrap(); //! //! // Use the built message. //! let message = builder.message(); //! # let _ = message; //! ``` -//! -//! # The Low-Level Interface -//! -//! [`Builder`] is a powerful low-level interface that can be used to build -//! DNS messages. It implements atomic building and name compression, and is -//! the foundation of [`MessageBuilder`]. -//! -//! The [`Builder`] interface does not know about questions and records; it is -//! only capable of appending simple bytes and compressing domain names. Its -//! access to the message buffer is limited; it can only append, modify, or -//! truncate the message up to a certain point (all data before that point is -//! immutable). Special attention is given to the message header, as it can -//! be modified at any point in the message building process. -//! -//! ``` -//! use domain::new_base::build::{BuilderContext, Builder, BuildIntoMessage}; -//! use domain::new_rdata::A; -//! -//! // Construct a builder for a particular buffer. -//! let mut buffer = [0u8; 20]; -//! let mut context = BuilderContext::default(); -//! let mut builder = Builder::new(&mut buffer, &mut context); -//! -//! // Try appending some raw bytes to the builder. -//! builder.append_bytes(b"hi! ").unwrap(); -//! assert_eq!(builder.appended(), b"hi! "); -//! -//! // Try appending some structured content to the builder. -//! A::from(std::net::Ipv4Addr::new(127, 0, 0, 1)) -//! .build_into_message(builder.delegate()) -//! .unwrap(); -//! assert_eq!(builder.appended(), b"hi! \x7F\x00\x00\x01"); -//! -//! // Finish using the builder. -//! builder.commit(); -//! -//! // Note: the first 12 bytes hold the message header. -//! assert_eq!(&buffer[12..20], b"hi! \x7F\x00\x00\x01"); -//! ``` mod builder; -pub use builder::{Builder, BuilderContext}; +pub use builder::Builder; + +mod context; +pub use context::{BuilderContext, MessageState}; mod message; pub use message::MessageBuilder; +mod question; +pub use question::QuestionBuilder; + mod record; pub use record::RecordBuilder; diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs new file mode 100644 index 000000000..7c8c8b1e8 --- /dev/null +++ b/src/new_base/build/question.rs @@ -0,0 +1,106 @@ +//! Building DNS questions. + +use crate::new_base::{ + name::UnparsedName, + parse::ParseMessageBytes, + wire::{ParseBytes, TruncationError}, + QClass, QType, Question, +}; + +use super::{BuildCommitted, BuildIntoMessage, MessageBuilder, MessageState}; + +//----------- QuestionBuilder ------------------------------------------------ + +/// A DNS question builder. +pub struct QuestionBuilder<'b> { + /// The underlying message builder. + builder: MessageBuilder<'b>, + + /// The offset of the question name. + name: u16, +} + +//--- Construction + +impl<'b> QuestionBuilder<'b> { + /// Build a [`Question`]. + /// + /// The provided builder must be empty (i.e. must not have uncommitted + /// content). + pub(super) fn build( + mut builder: MessageBuilder<'b>, + question: &Question, + ) -> Result { + // TODO: Require that the QNAME serialize correctly? + let start = builder.context.size; + question.build_into_message(builder.builder(start))?; + let name = start.try_into().expect("Messages are at most 64KiB"); + builder.context.state = MessageState::MidQuestion { name }; + Ok(Self { builder, name }) + } + + /// Reconstruct a [`QuestionBuilder`] from raw parts. + /// + /// # Safety + /// + /// `builder.message().contents[name..]` must represent a valid + /// [`Question`] in the wire format. + pub unsafe fn from_raw_parts( + builder: MessageBuilder<'b>, + name: u16, + ) -> Self { + Self { builder, name } + } +} + +//--- Inspection + +impl<'b> QuestionBuilder<'b> { + /// The (unparsed) question name. + pub fn qname(&self) -> &UnparsedName { + let contents = &self.builder.message().contents; + let contents = &contents[..contents.len() - 4]; + <&UnparsedName>::parse_message_bytes(contents, self.name.into()) + .expect("The question was serialized correctly") + } + + /// The question type. + pub fn qtype(&self) -> QType { + let contents = &self.builder.message().contents; + QType::parse_bytes(&contents[contents.len() - 4..contents.len() - 2]) + .expect("The question was serialized correctly") + } + + /// The question class. + pub fn qclass(&self) -> QClass { + let contents = &self.builder.message().contents; + QClass::parse_bytes(&contents[contents.len() - 2..]) + .expect("The question was serialized correctly") + } + + /// Deconstruct this [`QuestionBuilder`] into its raw parts. + pub fn into_raw_parts(self) -> (MessageBuilder<'b>, u16) { + (self.builder, self.name) + } +} + +//--- Interaction + +impl<'b> QuestionBuilder<'b> { + /// Commit this question. + /// + /// The builder will be consumed, and the question will be committed so + /// that it can no longer be removed. + pub fn commit(self) -> BuildCommitted { + self.builder.context.state = MessageState::Questions; + BuildCommitted + } + + /// Stop building and remove this question. + /// + /// The builder will be consumed, and the question will be removed. + pub fn cancel(self) { + self.builder.context.size = self.name.into(); + self.builder.context.state = MessageState::Questions; + } +} diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index aac0857c3..e1ef789ab 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -1,176 +1,230 @@ //! Building DNS records. +use core::{mem::ManuallyDrop, ptr}; + use crate::new_base::{ - name::RevName, - wire::{AsBytes, TruncationError}, - Header, Message, RClass, RType, TTL, + name::UnparsedName, + parse::ParseMessageBytes, + wire::{AsBytes, ParseBytes, SizePrefixed, TruncationError}, + RClass, RType, Record, TTL, }; -use super::{BuildCommitted, BuildIntoMessage, Builder}; +use super::{ + BuildCommitted, BuildIntoMessage, Builder, MessageBuilder, MessageState, +}; -//----------- RecordBuilder -------------------------------------------------- +//----------- RecordBuilder ------------------------------------------------ -/// A builder for a DNS record. -/// -/// This is used to incrementally build the data for a DNS record. It can be -/// constructed using [`MessageBuilder::build_answer()`] etc. -/// -/// [`MessageBuilder::build_answer()`]: super::MessageBuilder::build_answer() +/// A DNS record builder. pub struct RecordBuilder<'b> { - /// The underlying [`Builder`]. - /// - /// Its commit point lies at the beginning of the record. - inner: Builder<'b>, + /// The underlying message builder. + builder: MessageBuilder<'b>, - /// The position of the record data. - /// - /// This is an offset from the message contents. - start: usize, + /// The offset of the record name. + name: u16, - /// The section the record is a part of. - /// - /// The appropriate section count will be incremented on completion. - section: u8, + /// The offset of the record data. + data: u16, } -//--- Initialization +//--- Construction impl<'b> RecordBuilder<'b> { - /// Construct a [`RecordBuilder`] from raw parts. + /// Build a [`Record`]. + /// + /// The provided builder must be empty (i.e. must not have uncommitted + /// content). + pub(super) fn build( + mut builder: MessageBuilder<'b>, + record: &Record, + ) -> Result + where + N: BuildIntoMessage, + D: BuildIntoMessage, + { + // Build the record and remember important positions. + let start = builder.context.size; + let (name, data) = { + let name = start.try_into().expect("Messages are at most 64KiB"); + let mut b = builder.builder(start); + record.rname.build_into_message(b.delegate())?; + b.append_bytes(&record.rtype.as_bytes())?; + b.append_bytes(&record.rclass.as_bytes())?; + b.append_bytes(&record.ttl.as_bytes())?; + let size = b.context().size; + SizePrefixed::new(&record.rdata) + .build_into_message(b.delegate())?; + let data = + (size + 2).try_into().expect("Messages are at most 64KiB"); + (name, data) + }; + + // Update the message state. + match builder.context.state { + ref mut state @ MessageState::Answers => { + *state = MessageState::MidAnswer { name, data }; + } + + ref mut state @ MessageState::Authorities => { + *state = MessageState::MidAuthority { name, data }; + } + + ref mut state @ MessageState::Additionals => { + *state = MessageState::MidAdditional { name, data }; + } + + _ => unreachable!(), + } + + Ok(Self { + builder, + name, + data, + }) + } + + /// Reconstruct a [`RecordBuilder`] from raw parts. /// /// # Safety /// - /// - `builder`, `start`, and `section` are paired together. + /// `builder.message().contents[name..]` must represent a valid + /// [`Record`] in the wire format. `contents[data..]` must represent the + /// record data (i.e. immediately after the record data size field). pub unsafe fn from_raw_parts( - builder: Builder<'b>, - start: usize, - section: u8, + builder: MessageBuilder<'b>, + name: u16, + data: u16, ) -> Self { Self { - inner: builder, - start, - section, + builder, + name, + data, } } - - /// Initialize a new [`RecordBuilder`]. - /// - /// A new record with the given name, type, and class will be created. - /// The returned builder can be used to add data for the record. - /// - /// The count for the specified section (1, 2, or 3, i.e. answers, - /// authorities, and additional records respectively) will be incremented - /// when the builder finishes successfully. - pub fn new( - mut builder: Builder<'b>, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, - section: u8, - ) -> Result { - debug_assert_eq!(builder.appended(), &[] as &[u8]); - debug_assert!((1..4).contains(§ion)); - - assert!(builder - .header() - .counts - .as_array() - .iter() - .skip(1 + section as usize) - .all(|&c| c == 0)); - - // Build the record header. - rname.build_into_message(builder.delegate())?; - builder.append_bytes(rtype.as_bytes())?; - builder.append_bytes(rclass.as_bytes())?; - builder.append_bytes(ttl.as_bytes())?; - builder.append_bytes(&0u16.to_be_bytes())?; - let start = builder.appended().len(); - - // Set up the builder. - Ok(Self { - inner: builder, - start, - section, - }) - } } //--- Inspection impl<'b> RecordBuilder<'b> { - /// The message header. - pub fn header(&self) -> &Header { - self.inner.header() + /// The (unparsed) record name. + pub fn rname(&self) -> &UnparsedName { + let contents = &self.builder.message().contents; + let contents = &contents[..contents.len() - 4]; + <&UnparsedName>::parse_message_bytes(contents, self.name.into()) + .expect("The record was serialized correctly") } - /// The message without this record. - pub fn message(&self) -> &Message { - self.inner.message() + /// The record type. + pub fn rtype(&self) -> RType { + let contents = &self.builder.message().contents; + let contents = &contents[usize::from(self.data) - 8..]; + RType::parse_bytes(&contents[0..2]) + .expect("The record was serialized correctly") } - /// The record data appended thus far. - pub fn data(&self) -> &[u8] { - &self.inner.appended()[self.start..] + /// The record class. + pub fn rclass(&self) -> RClass { + let contents = &self.builder.message().contents; + let contents = &contents[usize::from(self.data) - 8..]; + RClass::parse_bytes(&contents[2..4]) + .expect("The record was serialized correctly") } - /// Decompose this builder into raw parts. - /// - /// This returns the underlying builder, the offset of the record data in - /// the record, and the section number for this record (1, 2, or 3). The - /// builder can be recomposed with [`Self::from_raw_parts()`]. - pub fn into_raw_parts(self) -> (Builder<'b>, usize, u8) { - (self.inner, self.start, self.section) + /// The TTL. + pub fn ttl(&self) -> TTL { + let contents = &self.builder.message().contents; + let contents = &contents[usize::from(self.data) - 8..]; + TTL::parse_bytes(&contents[4..8]) + .expect("The record was serialized correctly") + } + + /// The record data built thus far. + pub fn rdata(&self) -> &[u8] { + &self.builder.message().contents[usize::from(self.data)..] + } + + /// Deconstruct this [`RecordBuilder`] into its raw parts. + pub fn into_raw_parts(self) -> (MessageBuilder<'b>, u16, u16) { + let (name, data) = (self.name, self.data); + let this = ManuallyDrop::new(self); + let this = (&*this) as *const Self; + // SAFETY: 'this' is a valid object that can be moved out of. + let builder = unsafe { ptr::read(ptr::addr_of!((*this).builder)) }; + (builder, name, data) } } //--- Interaction -impl RecordBuilder<'_> { - /// Finish the record. +impl<'b> RecordBuilder<'b> { + /// Commit this record. /// - /// The respective section count will be incremented. The builder will be - /// consumed and the record will be committed. - pub fn finish(mut self) -> BuildCommitted { - // Increment the appropriate section count. - self.inner.header_mut().counts.as_array_mut() - [self.section as usize] += 1; - - // Set the record data length. - let size = self.inner.appended().len() - self.start; - let size = u16::try_from(size) - .expect("Record data must be smaller than 64KiB"); - // SAFETY: The record data size is not part of a compressed name. - let appended = unsafe { self.inner.appended_mut() }; - appended[self.start - 2..self.start] - .copy_from_slice(&size.to_be_bytes()); - - self.inner.commit() + /// The builder will be consumed, and the record will be committed so that + /// it can no longer be removed. + pub fn commit(self) -> BuildCommitted { + match self.builder.context.state { + ref mut state @ MessageState::MidAnswer { .. } => { + *state = MessageState::Answers; + } + + ref mut state @ MessageState::MidAuthority { .. } => { + *state = MessageState::Authorities; + } + + ref mut state @ MessageState::MidAdditional { .. } => { + *state = MessageState::Additionals; + } + + _ => unreachable!(), + } + + // NOTE: The record data size will be fixed on drop. + BuildCommitted } - /// Delegate to a new builder. + /// Stop building and remove this record. /// - /// Any content committed by the builder will be added as record data. - pub fn delegate(&mut self) -> Builder<'_> { - self.inner.delegate() + /// The builder will be consumed, and the record will be removed. + pub fn cancel(self) { + self.builder.context.size = self.name.into(); + match self.builder.context.state { + ref mut state @ MessageState::MidAnswer { .. } => { + *state = MessageState::Answers; + } + + ref mut state @ MessageState::MidAuthority { .. } => { + *state = MessageState::Authorities; + } + + ref mut state @ MessageState::MidAdditional { .. } => { + *state = MessageState::Additionals; + } + + _ => unreachable!(), + } + + // NOTE: The drop glue is a no-op. } - /// Append some bytes. - /// - /// No name compression will be performed. - pub fn append_bytes( - &mut self, - bytes: &[u8], - ) -> Result<(), TruncationError> { - self.inner.append_bytes(bytes) + /// Delegate further building of the record data to a new [`Builder`]. + pub fn delegate(&mut self) -> Builder<'_> { + let offset = self.builder.context.size; + self.builder.builder(offset) } +} - /// Compress and append a domain name. - pub fn append_name( - &mut self, - name: &RevName, - ) -> Result<(), TruncationError> { - self.inner.append_name(name) +//--- Drop + +impl Drop for RecordBuilder<'_> { + fn drop(&mut self) { + // Fixup the record data size so the overall message builder is valid. + let size = self.builder.context.size as u16; + if self.data <= size { + // SAFETY: Only the record data size field is being modified. + let message = unsafe { self.builder.message_mut() }; + let data = usize::from(self.data); + message.contents[data - 2..data] + .copy_from_slice(&size.to_be_bytes()); + } } } diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index 5ac9effa9..9053c6431 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -248,13 +248,13 @@ impl BuildIntoMessage for SizePrefixed { &self, mut builder: build::Builder<'_>, ) -> BuildResult { - assert_eq!(builder.appended(), &[] as &[u8]); + assert_eq!(builder.uncommitted(), &[] as &[u8]); builder.append_bytes(&0u16.to_be_bytes())?; self.data.build_into_message(builder.delegate())?; - let size = builder.appended().len() - 2; + let size = builder.uncommitted().len() - 2; let size = u16::try_from(size).expect("the data never exceeds 64KiB"); // SAFETY: A 'U16' is being modified, not a domain name. - let size_buf = unsafe { &mut builder.appended_mut()[0..2] }; + let size_buf = unsafe { &mut builder.uncommitted_mut()[0..2] }; size_buf.copy_from_slice(&size.to_be_bytes()); Ok(builder.commit()) } From 41db42a76c9c8eff277ede96619bb37ab23269f8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 12:05:44 +0100 Subject: [PATCH 098/111] [new_base/build/message] Add resumption methods --- src/new_base/build/context.rs | 2 +- src/new_base/build/message.rs | 75 ++++++++++++++++++++++++++++++++++ src/new_base/build/question.rs | 5 +++ src/new_base/build/record.rs | 5 +++ 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/new_base/build/context.rs b/src/new_base/build/context.rs index 2f7f43da1..bd423b4de 100644 --- a/src/new_base/build/context.rs +++ b/src/new_base/build/context.rs @@ -26,7 +26,7 @@ pub struct BuilderContext { /// additionals. [`MessageState`] remembers the start position of the last /// question or record in the message, allowing it to be modifying or removed /// (for additional flexibility in the building process). -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub enum MessageState { /// Questions are being built. /// diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 35ec3d145..288b4eefb 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -157,6 +157,24 @@ impl MessageBuilder<'_> { QuestionBuilder::build(self.reborrow(), question).map(Some) } + /// Resume building a question. + /// + /// If a question was built (using [`build_question()`]) but the returned + /// builder was neither committed nor canceled, the question builder will + /// be recovered and returned. + /// + /// [`build_question()`]: Self::build_question() + pub fn resume_question(&mut self) -> Option> { + let MessageState::MidQuestion { name } = self.context.state else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + QuestionBuilder::from_raw_parts(self.reborrow(), name) + }) + } + /// Build an answer record. /// /// If a question or answer is already being built, it will be finished @@ -186,6 +204,25 @@ impl MessageBuilder<'_> { RecordBuilder::build(self.reborrow(), &record).map(Some) } + /// Resume building an answer record. + /// + /// If an answer record was built (using [`build_answer()`]) but the + /// returned builder was neither committed nor canceled, the record + /// builder will be recovered and returned. + /// + /// [`build_answer()`]: Self::build_answer() + pub fn resume_answer(&mut self) -> Option> { + let MessageState::MidAnswer { name, data } = self.context.state + else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + RecordBuilder::from_raw_parts(self.reborrow(), name, data) + }) + } + /// Build an authority record. /// /// If a question, answer, or authority is already being built, it will be @@ -215,6 +252,25 @@ impl MessageBuilder<'_> { RecordBuilder::build(self.reborrow(), &record).map(Some) } + /// Resume building an authority record. + /// + /// If an authority record was built (using [`build_authority()`]) but + /// the returned builder was neither committed nor canceled, the record + /// builder will be recovered and returned. + /// + /// [`build_authority()`]: Self::build_authority() + pub fn resume_authority(&mut self) -> Option> { + let MessageState::MidAuthority { name, data } = self.context.state + else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + RecordBuilder::from_raw_parts(self.reborrow(), name, data) + }) + } + /// Build an additional record. /// /// If a question or record is already being built, it will be finished @@ -238,4 +294,23 @@ impl MessageBuilder<'_> { self.context.state = MessageState::Additionals; RecordBuilder::build(self.reborrow(), &record) } + + /// Resume building an additional record. + /// + /// If an additional record was built (using [`build_additional()`]) but + /// the returned builder was neither committed nor canceled, the record + /// builder will be recovered and returned. + /// + /// [`build_additional()`]: Self::build_additional() + pub fn resume_additional(&mut self) -> Option> { + let MessageState::MidAdditional { name, data } = self.context.state + else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + RecordBuilder::from_raw_parts(self.reborrow(), name, data) + }) + } } diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index 7c8c8b1e8..d16921eff 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -12,6 +12,11 @@ use super::{BuildCommitted, BuildIntoMessage, MessageBuilder, MessageState}; //----------- QuestionBuilder ------------------------------------------------ /// A DNS question builder. +/// +/// A [`QuestionBuilder`] provides control over a DNS question that has been +/// appended to a message (using a [`MessageBuilder`]). It can be used to +/// inspect the question's fields, to replace it with a new question, and to +/// commit (finish building) or cancel (remove) the question. pub struct QuestionBuilder<'b> { /// The underlying message builder. builder: MessageBuilder<'b>, diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index e1ef789ab..66bf96a09 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -16,6 +16,11 @@ use super::{ //----------- RecordBuilder ------------------------------------------------ /// A DNS record builder. +/// +/// A [`RecordBuilder`] provides access to a record that has been appended to +/// a DNS message (using a [`MessageBuilder`]). It can be used to inspect the +/// record, to (re)write the record data, and to commit (finish building) or +/// cancel (remove) the record. pub struct RecordBuilder<'b> { /// The underlying message builder. builder: MessageBuilder<'b>, From 2fcd114c2f26f8cfc4eeb06c7300fca691615a9d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 12:15:57 +0100 Subject: [PATCH 099/111] [new_base/build] Test 'MessageBuilder' and fix some bugs --- src/new_base/build/message.rs | 140 +++++++++++++++++++++++++++++++++ src/new_base/build/question.rs | 2 +- src/new_base/build/record.rs | 11 ++- src/new_base/record.rs | 32 ++++++++ 4 files changed, 180 insertions(+), 5 deletions(-) diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 288b4eefb..28967d448 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -314,3 +314,143 @@ impl MessageBuilder<'_> { }) } } + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use crate::{ + new_base::{ + build::{BuildIntoMessage, BuilderContext, MessageState}, + name::RevName, + QClass, QType, Question, RClass, RType, TTL, + }, + new_rdata::A, + }; + + use super::MessageBuilder; + + const WWW_EXAMPLE_ORG: &RevName = unsafe { + RevName::from_bytes_unchecked(b"\x00\x03org\x07example\x03www") + }; + + #[test] + fn new() { + let mut buffer = [0u8; 12]; + let mut context = BuilderContext::default(); + + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + assert_eq!(&builder.message().contents, &[] as &[u8]); + assert_eq!(unsafe { &builder.message_mut().contents }, &[] as &[u8]); + assert_eq!(builder.context().size, 0); + assert_eq!(builder.context().state, MessageState::Questions); + } + + #[test] + fn build_question() { + let mut buffer = [0u8; 33]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let question = Question { + qname: WWW_EXAMPLE_ORG, + qtype: QType::A, + qclass: QClass::IN, + }; + let qb = builder.build_question(&question).unwrap().unwrap(); + + assert_eq!(qb.qname().as_bytes(), b"\x03www\x07example\x03org\x00"); + assert_eq!(qb.qtype(), question.qtype); + assert_eq!(qb.qclass(), question.qclass); + + let state = MessageState::MidQuestion { name: 0 }; + assert_eq!(builder.context().state, state); + let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01"; + assert_eq!(&builder.message().contents, contents); + } + + #[test] + fn resume_question() { + let mut buffer = [0u8; 33]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let question = Question { + qname: WWW_EXAMPLE_ORG, + qtype: QType::A, + qclass: QClass::IN, + }; + let _ = builder.build_question(&question).unwrap().unwrap(); + + let qb = builder.resume_question().unwrap(); + assert_eq!(qb.qname().as_bytes(), b"\x03www\x07example\x03org\x00"); + assert_eq!(qb.qtype(), question.qtype); + assert_eq!(qb.qclass(), question.qclass); + } + + #[test] + fn build_record() { + let mut buffer = [0u8; 43]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + { + let mut rb = builder + .build_answer( + WWW_EXAMPLE_ORG, + RType::A, + RClass::IN, + TTL::from(42), + ) + .unwrap() + .unwrap(); + + assert_eq!( + rb.rname().as_bytes(), + b"\x03www\x07example\x03org\x00" + ); + assert_eq!(rb.rtype(), RType::A); + assert_eq!(rb.rclass(), RClass::IN); + assert_eq!(rb.ttl(), TTL::from(42)); + assert_eq!(rb.rdata(), b""); + + assert!(rb.delegate().append_bytes(&[0u8; 5]).is_err()); + + let rdata = A { + octets: [127, 0, 0, 1], + }; + rdata.build_into_message(rb.delegate()).unwrap(); + assert_eq!(rb.rdata(), b"\x7F\x00\x00\x01"); + } + + let state = MessageState::MidAnswer { name: 0, data: 27 }; + assert_eq!(builder.context().state, state); + let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01\x00\x00\x00\x2A\x00\x04\x7F\x00\x00\x01"; + assert_eq!(&builder.message().contents, contents.as_slice()); + } + + #[test] + fn resume_record() { + let mut buffer = [0u8; 39]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let _ = builder + .build_answer( + WWW_EXAMPLE_ORG, + RType::A, + RClass::IN, + TTL::from(42), + ) + .unwrap() + .unwrap(); + + let rb = builder.resume_answer().unwrap(); + assert_eq!(rb.rname().as_bytes(), b"\x03www\x07example\x03org\x00"); + assert_eq!(rb.rtype(), RType::A); + assert_eq!(rb.rclass(), RClass::IN); + assert_eq!(rb.ttl(), TTL::from(42)); + assert_eq!(rb.rdata(), b""); + } +} diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index d16921eff..72ae6bac0 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -64,7 +64,7 @@ impl<'b> QuestionBuilder<'b> { /// The (unparsed) question name. pub fn qname(&self) -> &UnparsedName { let contents = &self.builder.message().contents; - let contents = &contents[..contents.len() - 4]; + let contents = &contents[usize::from(self.name)..contents.len() - 4]; <&UnparsedName>::parse_message_bytes(contents, self.name.into()) .expect("The question was serialized correctly") } diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index 66bf96a09..707539593 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -61,6 +61,7 @@ impl<'b> RecordBuilder<'b> { .build_into_message(b.delegate())?; let data = (size + 2).try_into().expect("Messages are at most 64KiB"); + b.commit(); (name, data) }; @@ -114,7 +115,8 @@ impl<'b> RecordBuilder<'b> { /// The (unparsed) record name. pub fn rname(&self) -> &UnparsedName { let contents = &self.builder.message().contents; - let contents = &contents[..contents.len() - 4]; + let contents = + &contents[usize::from(self.name)..usize::from(self.data) - 10]; <&UnparsedName>::parse_message_bytes(contents, self.name.into()) .expect("The record was serialized correctly") } @@ -122,7 +124,7 @@ impl<'b> RecordBuilder<'b> { /// The record type. pub fn rtype(&self) -> RType { let contents = &self.builder.message().contents; - let contents = &contents[usize::from(self.data) - 8..]; + let contents = &contents[usize::from(self.data) - 10..]; RType::parse_bytes(&contents[0..2]) .expect("The record was serialized correctly") } @@ -130,7 +132,7 @@ impl<'b> RecordBuilder<'b> { /// The record class. pub fn rclass(&self) -> RClass { let contents = &self.builder.message().contents; - let contents = &contents[usize::from(self.data) - 8..]; + let contents = &contents[usize::from(self.data) - 10..]; RClass::parse_bytes(&contents[2..4]) .expect("The record was serialized correctly") } @@ -138,7 +140,7 @@ impl<'b> RecordBuilder<'b> { /// The TTL. pub fn ttl(&self) -> TTL { let contents = &self.builder.message().contents; - let contents = &contents[usize::from(self.data) - 8..]; + let contents = &contents[usize::from(self.data) - 10..]; TTL::parse_bytes(&contents[4..8]) .expect("The record was serialized correctly") } @@ -228,6 +230,7 @@ impl Drop for RecordBuilder<'_> { // SAFETY: Only the record data size field is being modified. let message = unsafe { self.builder.message_mut() }; let data = usize::from(self.data); + let size = size - self.data; message.contents[data - 2..data] .copy_from_slice(&size.to_be_bytes()); } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 0c0e6fa3c..f3663df60 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -273,6 +273,22 @@ pub struct RClass { pub code: U16, } +//--- Associated Constants + +impl RClass { + const fn new(value: u16) -> Self { + Self { + code: U16::new(value), + } + } + + /// The Internet class. + pub const IN: Self = Self::new(1); + + /// The CHAOS class. + pub const CH: Self = Self::new(3); +} + //----------- TTL ------------------------------------------------------------ /// How long a record can be cached. @@ -298,6 +314,22 @@ pub struct TTL { pub value: U32, } +//--- Conversion to and from integers + +impl From for TTL { + fn from(value: u32) -> Self { + Self { + value: U32::new(value), + } + } +} + +impl From for u32 { + fn from(value: TTL) -> Self { + value.value.get() + } +} + //----------- ParseRecordData ------------------------------------------------ /// Parsing DNS record data. From 074486c8be67994ea67506419f2c8ac552c43628 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 12:40:39 +0100 Subject: [PATCH 100/111] [new_base/build] Track section counts while building --- src/new_base/build/context.rs | 47 ++++++++++ src/new_base/build/message.rs | 158 ++++++++++++++++++--------------- src/new_base/build/mod.rs | 19 +++- src/new_base/build/question.rs | 1 + src/new_base/build/record.rs | 35 ++------ 5 files changed, 155 insertions(+), 105 deletions(-) diff --git a/src/new_base/build/context.rs b/src/new_base/build/context.rs index bd423b4de..112330123 100644 --- a/src/new_base/build/context.rs +++ b/src/new_base/build/context.rs @@ -2,6 +2,8 @@ //----------- BuilderContext ------------------------------------------------- +use crate::new_base::SectionCounts; + /// Context for building a DNS message. /// /// This type holds auxiliary information necessary for building DNS messages, @@ -132,4 +134,49 @@ impl MessageState { Self::Additionals | Self::MidAdditional { .. } => 3, } } + + /// Whether a question or record is being built. + pub const fn mid_component(&self) -> bool { + match self { + Self::MidQuestion { .. } => true, + Self::MidAnswer { .. } => true, + Self::MidAuthority { .. } => true, + Self::MidAdditional { .. } => true, + _ => false, + } + } + + /// Commit a question or record and update the section counts. + pub fn commit(&mut self, counts: &mut SectionCounts) { + match self { + Self::MidQuestion { .. } => { + counts.questions += 1; + *self = Self::Questions; + } + Self::MidAnswer { .. } => { + counts.answers += 1; + *self = Self::Answers; + } + Self::MidAuthority { .. } => { + counts.authorities += 1; + *self = Self::Authorities; + } + Self::MidAdditional { .. } => { + counts.additional += 1; + *self = Self::Additionals; + } + _ => {} + } + } + + /// Cancel a question or record. + pub fn cancel(&mut self) { + match self { + Self::MidQuestion { .. } => *self = Self::Questions, + Self::MidAnswer { .. } => *self = Self::Answers, + Self::MidAuthority { .. } => *self = Self::Authorities, + Self::MidAdditional { .. } => *self = Self::Additionals, + _ => {} + } + } } diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 28967d448..3101ecd76 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -4,7 +4,7 @@ use core::cell::UnsafeCell; use crate::new_base::{ wire::{ParseBytesByRef, TruncationError}, - Header, Message, Question, RClass, RType, Record, TTL, + Header, Message, Question, Record, }; use super::{ @@ -21,7 +21,7 @@ use super::{ /// (on the stack or the heap). pub struct MessageBuilder<'b> { /// The message being constructed. - message: &'b mut Message, + pub(super) message: &'b mut Message, /// Context for building. pub(super) context: &'b mut BuilderContext, @@ -148,12 +148,18 @@ impl MessageBuilder<'_> { &mut self, question: &Question, ) -> Result>, TruncationError> { - if self.context.state.section_index() > 0 { + let state = &mut self.context.state; + if state.section_index() > 0 { // We've progressed into a later section. return Ok(None); } - self.context.state = MessageState::Questions; + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } + + *state = MessageState::Questions; QuestionBuilder::build(self.reborrow(), question).map(Some) } @@ -180,28 +186,23 @@ impl MessageBuilder<'_> { /// If a question or answer is already being built, it will be finished /// first. If an authority or additional record has been added, [`None`] /// is returned instead. - pub fn build_answer( + pub fn build_answer( &mut self, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, + record: &Record, ) -> Result>, TruncationError> { - if self.context.state.section_index() > 1 { + let state = &mut self.context.state; + if state.section_index() > 1 { // We've progressed into a later section. return Ok(None); } - let record = Record { - rname, - rtype, - rclass, - ttl, - rdata: &[] as &[u8], - }; + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } - self.context.state = MessageState::Answers; - RecordBuilder::build(self.reborrow(), &record).map(Some) + *state = MessageState::Answers; + RecordBuilder::build(self.reborrow(), record).map(Some) } /// Resume building an answer record. @@ -228,28 +229,23 @@ impl MessageBuilder<'_> { /// If a question, answer, or authority is already being built, it will be /// finished first. If an additional record has been added, [`None`] is /// returned instead. - pub fn build_authority( + pub fn build_authority( &mut self, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, + record: &Record, ) -> Result>, TruncationError> { - if self.context.state.section_index() > 2 { + let state = &mut self.context.state; + if state.section_index() > 2 { // We've progressed into a later section. return Ok(None); } - let record = Record { - rname, - rtype, - rclass, - ttl, - rdata: &[] as &[u8], - }; + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } - self.context.state = MessageState::Authorities; - RecordBuilder::build(self.reborrow(), &record).map(Some) + *state = MessageState::Authorities; + RecordBuilder::build(self.reborrow(), record).map(Some) } /// Resume building an authority record. @@ -276,23 +272,18 @@ impl MessageBuilder<'_> { /// If a question or record is already being built, it will be finished /// first. Note that it is always possible to add an additional record to /// a message. - pub fn build_additional( + pub fn build_additional( &mut self, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, + record: &Record, ) -> Result, TruncationError> { - let record = Record { - rname, - rtype, - rclass, - ttl, - rdata: &[] as &[u8], - }; + let state = &mut self.context.state; + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } - self.context.state = MessageState::Additionals; - RecordBuilder::build(self.reborrow(), &record) + *state = MessageState::Additionals; + RecordBuilder::build(self.reborrow(), record) } /// Resume building an additional record. @@ -323,7 +314,9 @@ mod test { new_base::{ build::{BuildIntoMessage, BuilderContext, MessageState}, name::RevName, - QClass, QType, Question, RClass, RType, TTL, + wire::U16, + QClass, QType, Question, RClass, RType, Record, SectionCounts, + TTL, }, new_rdata::A, }; @@ -366,6 +359,7 @@ mod test { let state = MessageState::MidQuestion { name: 0 }; assert_eq!(builder.context().state, state); + assert_eq!(builder.message().header.counts, SectionCounts::default()); let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01"; assert_eq!(&builder.message().contents, contents); } @@ -387,6 +381,15 @@ mod test { assert_eq!(qb.qname().as_bytes(), b"\x03www\x07example\x03org\x00"); assert_eq!(qb.qtype(), question.qtype); assert_eq!(qb.qclass(), question.qclass); + + qb.commit(); + assert_eq!( + builder.message().header.counts, + SectionCounts { + questions: U16::new(1), + ..Default::default() + } + ); } #[test] @@ -395,24 +398,24 @@ mod test { let mut context = BuilderContext::default(); let mut builder = MessageBuilder::new(&mut buffer, &mut context); + let record = Record { + rname: WWW_EXAMPLE_ORG, + rtype: RType::A, + rclass: RClass::IN, + ttl: TTL::from(42), + rdata: b"", + }; + { - let mut rb = builder - .build_answer( - WWW_EXAMPLE_ORG, - RType::A, - RClass::IN, - TTL::from(42), - ) - .unwrap() - .unwrap(); + let mut rb = builder.build_answer(&record).unwrap().unwrap(); assert_eq!( rb.rname().as_bytes(), b"\x03www\x07example\x03org\x00" ); - assert_eq!(rb.rtype(), RType::A); - assert_eq!(rb.rclass(), RClass::IN); - assert_eq!(rb.ttl(), TTL::from(42)); + assert_eq!(rb.rtype(), record.rtype); + assert_eq!(rb.rclass(), record.rclass); + assert_eq!(rb.ttl(), record.ttl); assert_eq!(rb.rdata(), b""); assert!(rb.delegate().append_bytes(&[0u8; 5]).is_err()); @@ -426,6 +429,7 @@ mod test { let state = MessageState::MidAnswer { name: 0, data: 27 }; assert_eq!(builder.context().state, state); + assert_eq!(builder.message().header.counts, SectionCounts::default()); let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01\x00\x00\x00\x2A\x00\x04\x7F\x00\x00\x01"; assert_eq!(&builder.message().contents, contents.as_slice()); } @@ -436,21 +440,29 @@ mod test { let mut context = BuilderContext::default(); let mut builder = MessageBuilder::new(&mut buffer, &mut context); - let _ = builder - .build_answer( - WWW_EXAMPLE_ORG, - RType::A, - RClass::IN, - TTL::from(42), - ) - .unwrap() - .unwrap(); + let record = Record { + rname: WWW_EXAMPLE_ORG, + rtype: RType::A, + rclass: RClass::IN, + ttl: TTL::from(42), + rdata: b"", + }; + let _ = builder.build_answer(&record).unwrap().unwrap(); let rb = builder.resume_answer().unwrap(); assert_eq!(rb.rname().as_bytes(), b"\x03www\x07example\x03org\x00"); - assert_eq!(rb.rtype(), RType::A); - assert_eq!(rb.rclass(), RClass::IN); - assert_eq!(rb.ttl(), TTL::from(42)); + assert_eq!(rb.rtype(), record.rtype); + assert_eq!(rb.rclass(), record.rclass); + assert_eq!(rb.ttl(), record.ttl); assert_eq!(rb.rdata(), b""); + + rb.commit(); + assert_eq!( + builder.message().header.counts, + SectionCounts { + answers: U16::new(1), + ..Default::default() + } + ); } } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index ba62062b6..871d20571 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -81,13 +81,28 @@ impl BuildIntoMessage for &T { } } -impl BuildIntoMessage for [u8] { +impl BuildIntoMessage for u8 { fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { - builder.append_bytes(self)?; + builder.append_bytes(&[*self])?; Ok(builder.commit()) } } +impl BuildIntoMessage for [T] { + fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { + for elem in self { + elem.build_into_message(builder.delegate())?; + } + Ok(builder.commit()) + } +} + +impl BuildIntoMessage for [T; N] { + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult { + self.as_slice().build_into_message(builder) + } +} + //----------- BuildResult ---------------------------------------------------- /// The result of building into a DNS message. diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index 72ae6bac0..680b828fd 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -98,6 +98,7 @@ impl<'b> QuestionBuilder<'b> { /// that it can no longer be removed. pub fn commit(self) -> BuildCommitted { self.builder.context.state = MessageState::Questions; + self.builder.message.header.counts.questions += 1; BuildCommitted } diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index 707539593..6a27826c3 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -169,21 +169,10 @@ impl<'b> RecordBuilder<'b> { /// The builder will be consumed, and the record will be committed so that /// it can no longer be removed. pub fn commit(self) -> BuildCommitted { - match self.builder.context.state { - ref mut state @ MessageState::MidAnswer { .. } => { - *state = MessageState::Answers; - } - - ref mut state @ MessageState::MidAuthority { .. } => { - *state = MessageState::Authorities; - } - - ref mut state @ MessageState::MidAdditional { .. } => { - *state = MessageState::Additionals; - } - - _ => unreachable!(), - } + self.builder + .context + .state + .commit(&mut self.builder.message.header.counts); // NOTE: The record data size will be fixed on drop. BuildCommitted @@ -194,21 +183,7 @@ impl<'b> RecordBuilder<'b> { /// The builder will be consumed, and the record will be removed. pub fn cancel(self) { self.builder.context.size = self.name.into(); - match self.builder.context.state { - ref mut state @ MessageState::MidAnswer { .. } => { - *state = MessageState::Answers; - } - - ref mut state @ MessageState::MidAuthority { .. } => { - *state = MessageState::Authorities; - } - - ref mut state @ MessageState::MidAdditional { .. } => { - *state = MessageState::Additionals; - } - - _ => unreachable!(), - } + self.builder.context.state.cancel(); // NOTE: The drop glue is a no-op. } From 9a5e19313bbea2a69d7124573affdff497742184 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 12:40:52 +0100 Subject: [PATCH 101/111] [new_base/wire/parse] Fix miscount of doc test --- src/new_base/wire/parse.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/new_base/wire/parse.rs b/src/new_base/wire/parse.rs index 2a8da6642..2e517c197 100644 --- a/src/new_base/wire/parse.rs +++ b/src/new_base/wire/parse.rs @@ -189,7 +189,7 @@ pub unsafe trait ParseBytesByRef { /// may be provided. Until then, it should be implemented using one of /// the following expressions: /// - /// ```ignore + /// ```text /// fn ptr_with_address( /// &self, /// addr: *const (), From 1a54da1e46d628638d79d22c655a87514be34435 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 13:13:18 +0100 Subject: [PATCH 102/111] [new_base/wire/parse] Support parsing into arrays --- src/new_base/wire/parse.rs | 54 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/new_base/wire/parse.rs b/src/new_base/wire/parse.rs index 2e517c197..5ab9bedd5 100644 --- a/src/new_base/wire/parse.rs +++ b/src/new_base/wire/parse.rs @@ -1,6 +1,7 @@ //! Parsing bytes in the basic network format. use core::fmt; +use core::mem::MaybeUninit; //----------- ParseBytes ----------------------------------------------------- @@ -29,6 +30,15 @@ impl<'a, T: ?Sized + ParseBytesByRef> ParseBytes<'a> for &'a T { } } +impl<'a, T: SplitBytes<'a>, const N: usize> ParseBytes<'a> for [T; N] { + fn parse_bytes(bytes: &'a [u8]) -> Result { + match Self::split_bytes(bytes) { + Ok((this, &[])) => Ok(this), + _ => Err(ParseError), + } + } +} + /// Deriving [`ParseBytes`] automatically. /// /// [`ParseBytes`] can be derived on `struct`s (not `enum`s or `union`s). All @@ -86,6 +96,50 @@ impl<'a> SplitBytes<'a> for u8 { } } +impl<'a, T: SplitBytes<'a>, const N: usize> SplitBytes<'a> for [T; N] { + fn split_bytes( + mut bytes: &'a [u8], + ) -> Result<(Self, &'a [u8]), ParseError> { + // TODO: Rewrite when either 'array_try_map' or 'try_array_from_fn' + // is stabilized. + + /// A guard for dropping initialized elements on panic / failure. + struct Guard { + buffer: [MaybeUninit; N], + initialized: usize, + } + + impl Drop for Guard { + fn drop(&mut self) { + for elem in &mut self.buffer[..self.initialized] { + // SAFETY: The first 'initialized' elems are initialized. + unsafe { elem.assume_init_drop() }; + } + } + } + + let mut guard = Guard:: { + buffer: [const { MaybeUninit::uninit() }; N], + initialized: 0, + }; + + while guard.initialized < N { + let (elem, rest) = T::split_bytes(bytes)?; + guard.buffer[guard.initialized].write(elem); + bytes = rest; + guard.initialized += 1; + } + + // Disable the guard since we're moving data out now. + guard.initialized = 0; + + // SAFETY: '[MaybeUninit; N]' and '[T; N]' have the same layout, + // because 'MaybeUninit' and 'T' have the same layout, because it + // is documented in the standard library. + Ok((unsafe { core::mem::transmute_copy(&guard.buffer) }, bytes)) + } +} + /// Deriving [`SplitBytes`] automatically. /// /// [`SplitBytes`] can be derived on `struct`s (not `enum`s or `union`s). All From 00aaaf77a8e07a42d7bd84bf39474de71c0cd510 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 13:16:40 +0100 Subject: [PATCH 103/111] [new_base/build] Accept clippy suggestions --- src/new_base/build/builder.rs | 6 +++--- src/new_base/build/context.rs | 14 +++++++------- src/new_base/build/message.rs | 4 ++-- src/new_base/build/question.rs | 2 +- src/new_base/build/record.rs | 8 ++++---- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index edcc9b543..34c415ce3 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -177,7 +177,7 @@ impl<'b> Builder<'b> { pub fn uncommitted(&self) -> &[u8] { let message = self.contents.get().cast::().cast_const(); // SAFETY: It is guaranteed that 'start <= message.len()'. - let message = unsafe { message.offset(self.start as isize) }; + let message = unsafe { message.add(self.start) }; let size = self.context.size - self.start; // SAFETY: 'message[start..]' is mutably borrowed. unsafe { slice::from_raw_parts(message, size) } @@ -192,7 +192,7 @@ impl<'b> Builder<'b> { pub unsafe fn uncommitted_mut(&mut self) -> &mut [u8] { let message = self.contents.get().cast::(); // SAFETY: It is guaranteed that 'start <= message.len()'. - let message = unsafe { message.offset(self.start as isize) }; + let message = unsafe { message.add(self.start) }; let size = self.context.size - self.start; // SAFETY: 'message[start..]' is mutably borrowed. unsafe { slice::from_raw_parts_mut(message, size) } @@ -206,7 +206,7 @@ impl<'b> Builder<'b> { pub fn uninitialized(&mut self) -> &mut [u8] { let message = self.contents.get().cast::(); // SAFETY: It is guaranteed that 'size <= message.len()'. - let message = unsafe { message.offset(self.context.size as isize) }; + let message = unsafe { message.add(self.context.size) }; let size = self.max_size() - self.context.size; // SAFETY: 'message[size..]' is mutably borrowed. unsafe { slice::from_raw_parts_mut(message, size) } diff --git a/src/new_base/build/context.rs b/src/new_base/build/context.rs index 112330123..e62ad265b 100644 --- a/src/new_base/build/context.rs +++ b/src/new_base/build/context.rs @@ -137,13 +137,13 @@ impl MessageState { /// Whether a question or record is being built. pub const fn mid_component(&self) -> bool { - match self { - Self::MidQuestion { .. } => true, - Self::MidAnswer { .. } => true, - Self::MidAuthority { .. } => true, - Self::MidAdditional { .. } => true, - _ => false, - } + matches!( + self, + Self::MidQuestion { .. } + | Self::MidAnswer { .. } + | Self::MidAuthority { .. } + | Self::MidAdditional { .. } + ) } /// Commit a question or record and update the section counts. diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 3101ecd76..5e969115f 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -52,7 +52,7 @@ impl<'b> MessageBuilder<'b> { //--- Inspection -impl<'b> MessageBuilder<'b> { +impl MessageBuilder<'_> { /// The message header. pub fn header(&self) -> &Header { &self.message.header @@ -135,7 +135,7 @@ impl MessageBuilder<'_> { unsafe { let contents = &mut self.message.contents; let contents = contents as *mut [u8] as *const UnsafeCell<[u8]>; - Builder::from_raw_parts(&*contents, &mut self.context, start) + Builder::from_raw_parts(&*contents, self.context, start) } } diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index 680b828fd..95fa095ae 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -91,7 +91,7 @@ impl<'b> QuestionBuilder<'b> { //--- Interaction -impl<'b> QuestionBuilder<'b> { +impl QuestionBuilder<'_> { /// Commit this question. /// /// The builder will be consumed, and the question will be committed so diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index 6a27826c3..f74418a13 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -53,9 +53,9 @@ impl<'b> RecordBuilder<'b> { let name = start.try_into().expect("Messages are at most 64KiB"); let mut b = builder.builder(start); record.rname.build_into_message(b.delegate())?; - b.append_bytes(&record.rtype.as_bytes())?; - b.append_bytes(&record.rclass.as_bytes())?; - b.append_bytes(&record.ttl.as_bytes())?; + b.append_bytes(record.rtype.as_bytes())?; + b.append_bytes(record.rclass.as_bytes())?; + b.append_bytes(record.ttl.as_bytes())?; let size = b.context().size; SizePrefixed::new(&record.rdata) .build_into_message(b.delegate())?; @@ -163,7 +163,7 @@ impl<'b> RecordBuilder<'b> { //--- Interaction -impl<'b> RecordBuilder<'b> { +impl RecordBuilder<'_> { /// Commit this record. /// /// The builder will be consumed, and the record will be committed so that From 1112224cd472ba69c8f46d0cea87ab2c9a4e2a82 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 13:20:25 +0100 Subject: [PATCH 104/111] [new_rdata] Use 'core::net::Ipv4Addr' --- src/new_rdata/basic.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 43d089100..2f4e24035 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -3,13 +3,9 @@ //! See [RFC 1035](https://datatracker.ietf.org/doc/html/rfc1035). use core::fmt; - -#[cfg(feature = "std")] +use core::net::Ipv4Addr; use core::str::FromStr; -#[cfg(feature = "std")] -use std::net::Ipv4Addr; - use domain_macros::*; use crate::new_base::{ @@ -46,7 +42,6 @@ pub struct A { //--- Converting to and from 'Ipv4Addr' -#[cfg(feature = "std")] impl From for A { fn from(value: Ipv4Addr) -> Self { Self { @@ -55,7 +50,6 @@ impl From for A { } } -#[cfg(feature = "std")] impl From for Ipv4Addr { fn from(value: A) -> Self { Self::from(value.octets) @@ -64,7 +58,6 @@ impl From for Ipv4Addr { //--- Parsing from a string -#[cfg(feature = "std")] impl FromStr for A { type Err = ::Err; @@ -77,8 +70,7 @@ impl FromStr for A { impl fmt::Display for A { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let [a, b, c, d] = self.octets; - write!(f, "{a}.{b}.{c}.{d}") + Ipv4Addr::from(*self).fmt(f) } } From c808946addf755900864e38e7fd02c938e745346 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 13:25:04 +0100 Subject: [PATCH 105/111] [new_rdata] Document format in 'Txt' See: --- src/new_rdata/basic.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 2f4e24035..e880721da 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -437,6 +437,8 @@ impl BuildIntoMessage for Mx { #[repr(transparent)] pub struct Txt { /// The text strings, as concatenated [`CharStr`]s. + /// + /// The [`CharStr`]s begin with a length octet so they can be separated. content: [u8], } From 2d3e9110fb541e5cd6ee4dbde67e426718cf9c2d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 18:17:30 +0100 Subject: [PATCH 106/111] [new_base/name] Add 'LabelBuf' and support building on 'Label' --- src/new_base/name/label.rs | 221 ++++++++++++++++++++++++++++++++++++- src/new_base/name/mod.rs | 2 +- 2 files changed, 220 insertions(+), 3 deletions(-) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 9cb4d1d85..3c5c44239 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -1,15 +1,21 @@ //! Labels in domain names. use core::{ + borrow::{Borrow, BorrowMut}, cmp::Ordering, fmt, hash::{Hash, Hasher}, iter::FusedIterator, + ops::{Deref, DerefMut}, }; use domain_macros::AsBytes; -use crate::new_base::wire::{ParseBytes, ParseError, SplitBytes}; +use crate::new_base::{ + build::{BuildIntoMessage, BuildResult, Builder}, + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, +}; //----------- Label ---------------------------------------------------------- @@ -48,9 +54,52 @@ impl Label { // SAFETY: 'Label' is 'repr(transparent)' to '[u8]'. unsafe { core::mem::transmute(bytes) } } + + /// Assume a mutable byte slice is a valid label. + /// + /// # Safety + /// + /// The byte slice must have length 63 or less. + pub unsafe fn from_bytes_unchecked_mut(bytes: &mut [u8]) -> &mut Self { + // SAFETY: 'Label' is 'repr(transparent)' to '[u8]'. + unsafe { core::mem::transmute(bytes) } + } } -//--- Parsing +//--- Parsing from DNS messages + +impl<'a> ParseMessageBytes<'a> for &'a Label { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + Self::parse_bytes(&contents[start..]) + } +} + +impl<'a> SplitMessageBytes<'a> for &'a Label { + fn split_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result<(Self, usize), ParseError> { + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Label { + fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { + builder.append_with(self.len() + 1, |buf| { + buf[0] = self.len() as u8; + buf[1..].copy_from_slice(self.as_bytes()); + })?; + Ok(builder.commit()) + } +} + +//--- Parsing from bytes impl<'a> SplitBytes<'a> for &'a Label { fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { @@ -73,6 +122,20 @@ impl<'a> ParseBytes<'a> for &'a Label { } } +//--- Building into byte strings + +impl BuildBytes for Label { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + let (size, data) = bytes.split_first_mut().ok_or(TruncationError)?; + let rest = self.as_bytes().build_bytes(data)?; + *size = self.len() as u8; + Ok(rest) + } +} + //--- Inspection impl Label { @@ -209,6 +272,160 @@ impl fmt::Debug for Label { } } +//----------- LabelBuf ------------------------------------------------------- + +/// A 64-byte buffer holding a [`Label`]. +#[derive(Clone)] +#[repr(C)] // make layout compatible with '[u8; 64]' +pub struct LabelBuf { + /// The size of the label, in bytes. + /// + /// This value is guaranteed to be in the range '0..64'. + size: u8, + + /// The underlying label data. + data: [u8; 63], +} + +//--- Construction + +impl LabelBuf { + /// Copy a [`Label`] into a buffer. + pub fn copy_from(label: &Label) -> Self { + let size = label.len() as u8; + let mut data = [0u8; 63]; + data[..size as usize].copy_from_slice(label.as_bytes()); + Self { size, data } + } +} + +//--- Parsing from DNS messages + +impl ParseMessageBytes<'_> for LabelBuf { + fn parse_message_bytes( + contents: &'_ [u8], + start: usize, + ) -> Result { + Self::parse_bytes(&contents[start..]) + } +} + +impl SplitMessageBytes<'_> for LabelBuf { + fn split_message_bytes( + contents: &'_ [u8], + start: usize, + ) -> Result<(Self, usize), ParseError> { + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for LabelBuf { + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult { + (**self).build_into_message(builder) + } +} + +//--- Parsing from byte strings + +impl ParseBytes<'_> for LabelBuf { + fn parse_bytes(bytes: &[u8]) -> Result { + <&Label>::parse_bytes(bytes).map(Self::copy_from) + } +} + +impl SplitBytes<'_> for LabelBuf { + fn split_bytes(bytes: &'_ [u8]) -> Result<(Self, &'_ [u8]), ParseError> { + <&Label>::split_bytes(bytes) + .map(|(label, rest)| (Self::copy_from(label), rest)) + } +} + +//--- Building into byte strings + +impl BuildBytes for LabelBuf { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_bytes(bytes) + } +} + +//--- Access to the underlying 'Label' + +impl Deref for LabelBuf { + type Target = Label; + + fn deref(&self) -> &Self::Target { + let label = &self.data[..self.size as usize]; + // SAFETY: A 'LabelBuf' always contains a valid 'Label'. + unsafe { Label::from_bytes_unchecked(label) } + } +} + +impl DerefMut for LabelBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + let label = &mut self.data[..self.size as usize]; + // SAFETY: A 'LabelBuf' always contains a valid 'Label'. + unsafe { Label::from_bytes_unchecked_mut(label) } + } +} + +impl Borrow