From 37127c4b64257f2d802cd93769733ab1e4775c9d Mon Sep 17 00:00:00 2001 From: parkma99 Date: Tue, 29 Oct 2024 20:21:11 +0800 Subject: [PATCH 1/2] vec and array impl for DnsResolver Co-authored-by: Glen De Cauwsemaecker --- rama-dns/src/lib.rs | 83 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/rama-dns/src/lib.rs b/rama-dns/src/lib.rs index d358c098..a722d83e 100644 --- a/rama-dns/src/lib.rs +++ b/rama-dns/src/lib.rs @@ -79,6 +79,89 @@ impl>> DnsResolver for Option { } } +#[derive(Debug)] +pub struct DnsChainDomainResolveErr { + errors: Vec, +} + +impl std::fmt::Display for DnsChainDomainResolveErr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "domain resolver chain resulted in errors: {:?}", + self.errors + ) + } +} + +impl std::error::Error + for DnsChainDomainResolveErr +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.errors.last().map(|e| e as &dyn std::error::Error) + } +} + +impl DnsResolver for Vec +where + R: DnsResolver + Send, + E: Send + 'static, +{ + type Error = DnsChainDomainResolveErr; + + async fn ipv4_lookup(&self, domain: Domain) -> Result, Self::Error> { + let mut errors = Vec::new(); + for resolver in self { + match resolver.ipv4_lookup(domain.clone()).await { + Ok(ipv4s) => return Ok(ipv4s), + Err(err) => errors.push(err), + } + } + Err(DnsChainDomainResolveErr { errors }) + } + + async fn ipv6_lookup(&self, domain: Domain) -> Result, Self::Error> { + let mut errors = Vec::new(); + for resolver in self { + match resolver.ipv6_lookup(domain.clone()).await { + Ok(ipv6s) => return Ok(ipv6s), + Err(err) => errors.push(err), + } + } + Err(DnsChainDomainResolveErr { errors }) + } +} + +impl DnsResolver for [R; N] +where + R: DnsResolver + Send, + E: Send + 'static, +{ + type Error = DnsChainDomainResolveErr; + + async fn ipv4_lookup(&self, domain: Domain) -> Result, Self::Error> { + let mut errors = Vec::new(); + for resolver in self { + match resolver.ipv4_lookup(domain.clone()).await { + Ok(ipv4s) => return Ok(ipv4s), + Err(err) => errors.push(err), + } + } + Err(DnsChainDomainResolveErr { errors }) + } + + async fn ipv6_lookup(&self, domain: Domain) -> Result, Self::Error> { + let mut errors = Vec::new(); + for resolver in self { + match resolver.ipv6_lookup(domain.clone()).await { + Ok(ipv6s) => return Ok(ipv6s), + Err(err) => errors.push(err), + } + } + Err(DnsChainDomainResolveErr { errors }) + } +} + macro_rules! impl_dns_resolver_either_either { ($id:ident, $($param:ident),+ $(,)?) => { impl<$($param),+> DnsResolver for ::rama_core::combinators::$id<$($param),+> From 010b7f53034a1248ab40ce6065b7ae58cc86ba4a Mon Sep 17 00:00:00 2001 From: parkma99 Date: Tue, 29 Oct 2024 22:37:31 +0800 Subject: [PATCH 2/2] move to chain.rs --- rama-dns/src/chain.rs | 127 ++++++++++++++++++++++++++++++++++++++++ rama-dns/src/lib.rs | 126 ++------------------------------------- rama-dns/src/variant.rs | 41 +++++++++++++ 3 files changed, 172 insertions(+), 122 deletions(-) create mode 100644 rama-dns/src/chain.rs create mode 100644 rama-dns/src/variant.rs diff --git a/rama-dns/src/chain.rs b/rama-dns/src/chain.rs new file mode 100644 index 00000000..05557d2c --- /dev/null +++ b/rama-dns/src/chain.rs @@ -0,0 +1,127 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +use rama_net::address::Domain; + +use crate::DnsResolver; + +#[derive(Debug)] +pub struct DnsChainDomainResolveErr { + errors: Vec, +} + +impl std::fmt::Display for DnsChainDomainResolveErr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "domain resolver chain resulted in errors: {:?}", + self.errors + ) + } +} + +impl std::error::Error for DnsChainDomainResolveErr { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.errors.last() + } +} + +macro_rules! dns_resolver_chain_impl { + () => { + async fn ipv4_lookup(&self, domain: Domain) -> Result, Self::Error> { + let mut errors = Vec::new(); + for resolver in self { + match resolver.ipv4_lookup(domain.clone()).await { + Ok(ipv4s) => return Ok(ipv4s), + Err(err) => errors.push(err), + } + } + Err(DnsChainDomainResolveErr { errors }) + } + + async fn ipv6_lookup(&self, domain: Domain) -> Result, Self::Error> { + let mut errors = Vec::new(); + for resolver in self { + match resolver.ipv6_lookup(domain.clone()).await { + Ok(ipv6s) => return Ok(ipv6s), + Err(err) => errors.push(err), + } + } + Err(DnsChainDomainResolveErr { errors }) + } + }; +} + +impl DnsResolver for Vec +where + R: DnsResolver + Send, + E: Send + 'static, +{ + type Error = DnsChainDomainResolveErr; + + dns_resolver_chain_impl!(); +} + +impl DnsResolver for [R; N] +where + R: DnsResolver + Send, + E: Send + 'static, +{ + type Error = DnsChainDomainResolveErr; + dns_resolver_chain_impl!(); +} + +#[cfg(test)] +mod tests { + use crate::{DenyAllDns, DnsOverwrite, InMemoryDns}; + use rama_core::combinators::Either; + use std::net::{Ipv4Addr, Ipv6Addr}; + + use super::*; + + #[tokio::test] + async fn test_empty_chain_vec() { + let v = Vec::::new(); + assert!(v + .ipv4_lookup(Domain::from_static("plabayo.tech")) + .await + .is_err()); + assert!(v + .ipv6_lookup(Domain::from_static("plabayo.tech")) + .await + .is_err()); + } + + #[tokio::test] + async fn test_empty_chain_array() { + let a: [InMemoryDns; 0] = []; + assert!(a + .ipv4_lookup(Domain::from_static("plabayo.tech")) + .await + .is_err()); + assert!(a + .ipv6_lookup(Domain::from_static("plabayo.tech")) + .await + .is_err()); + } + + // #[tokio::test] + // async fn test_chain_ok_err_ipv4() { + // let v: Vec> = vec![ + // Either::A(serde_html_form::from_str("example.com=127.0.0.1").unwrap()), + // Either::B(DenyAllDns::new()), + // ]; + // assert_eq!( + // v.ipv4_lookup(Domain::from_static("example.com")) + // .await + // .unwrap() + // .into_iter() + // .next() + // .unwrap(), + // Ipv4Addr::new(127, 0, 0, 1) + // ); + // assert!(v + // .ipv6_lookup(Domain::from_static("example.com")) + // .await + // .is_err()); + // } +} diff --git a/rama-dns/src/lib.rs b/rama-dns/src/lib.rs index a722d83e..066925e1 100644 --- a/rama-dns/src/lib.rs +++ b/rama-dns/src/lib.rs @@ -79,128 +79,6 @@ impl>> DnsResolver for Option { } } -#[derive(Debug)] -pub struct DnsChainDomainResolveErr { - errors: Vec, -} - -impl std::fmt::Display for DnsChainDomainResolveErr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "domain resolver chain resulted in errors: {:?}", - self.errors - ) - } -} - -impl std::error::Error - for DnsChainDomainResolveErr -{ - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.errors.last().map(|e| e as &dyn std::error::Error) - } -} - -impl DnsResolver for Vec -where - R: DnsResolver + Send, - E: Send + 'static, -{ - type Error = DnsChainDomainResolveErr; - - async fn ipv4_lookup(&self, domain: Domain) -> Result, Self::Error> { - let mut errors = Vec::new(); - for resolver in self { - match resolver.ipv4_lookup(domain.clone()).await { - Ok(ipv4s) => return Ok(ipv4s), - Err(err) => errors.push(err), - } - } - Err(DnsChainDomainResolveErr { errors }) - } - - async fn ipv6_lookup(&self, domain: Domain) -> Result, Self::Error> { - let mut errors = Vec::new(); - for resolver in self { - match resolver.ipv6_lookup(domain.clone()).await { - Ok(ipv6s) => return Ok(ipv6s), - Err(err) => errors.push(err), - } - } - Err(DnsChainDomainResolveErr { errors }) - } -} - -impl DnsResolver for [R; N] -where - R: DnsResolver + Send, - E: Send + 'static, -{ - type Error = DnsChainDomainResolveErr; - - async fn ipv4_lookup(&self, domain: Domain) -> Result, Self::Error> { - let mut errors = Vec::new(); - for resolver in self { - match resolver.ipv4_lookup(domain.clone()).await { - Ok(ipv4s) => return Ok(ipv4s), - Err(err) => errors.push(err), - } - } - Err(DnsChainDomainResolveErr { errors }) - } - - async fn ipv6_lookup(&self, domain: Domain) -> Result, Self::Error> { - let mut errors = Vec::new(); - for resolver in self { - match resolver.ipv6_lookup(domain.clone()).await { - Ok(ipv6s) => return Ok(ipv6s), - Err(err) => errors.push(err), - } - } - Err(DnsChainDomainResolveErr { errors }) - } -} - -macro_rules! impl_dns_resolver_either_either { - ($id:ident, $($param:ident),+ $(,)?) => { - impl<$($param),+> DnsResolver for ::rama_core::combinators::$id<$($param),+> - where - $($param: DnsResolver>),+, - { - type Error = ::rama_core::error::BoxError; - - async fn ipv4_lookup( - &self, - domain: Domain, - ) -> Result, Self::Error>{ - match self { - $( - ::rama_core::combinators::$id::$param(d) => d.ipv4_lookup(domain) - .await - .map_err(Into::into), - )+ - } - } - - async fn ipv6_lookup( - &self, - domain: Domain, - ) -> Result, Self::Error> { - match self { - $( - ::rama_core::combinators::$id::$param(d) => d.ipv6_lookup(domain) - .await - .map_err(Into::into), - )+ - } - } - } - }; -} - -rama_core::combinators::impl_either!(impl_dns_resolver_either_either); - pub mod hickory; #[doc(inline)] pub use hickory::HickoryDns; @@ -212,3 +90,7 @@ pub use in_memory::{DnsOverwrite, DomainNotMappedErr, InMemoryDns}; mod deny_all; #[doc(inline)] pub use deny_all::{DenyAllDns, DnsDeniedError}; + +pub mod chain; + +pub mod variant; diff --git a/rama-dns/src/variant.rs b/rama-dns/src/variant.rs new file mode 100644 index 00000000..250af997 --- /dev/null +++ b/rama-dns/src/variant.rs @@ -0,0 +1,41 @@ +use crate::DnsResolver; +use rama_net::address::Domain; +use std::net::{Ipv4Addr, Ipv6Addr}; +macro_rules! impl_dns_resolver_either_either { + ($id:ident, $($param:ident),+ $(,)?) => { + impl<$($param),+> DnsResolver for ::rama_core::combinators::$id<$($param),+> + where + $($param: DnsResolver>),+, + { + type Error = ::rama_core::error::BoxError; + + async fn ipv4_lookup( + &self, + domain: Domain, + ) -> Result, Self::Error>{ + match self { + $( + ::rama_core::combinators::$id::$param(d) => d.ipv4_lookup(domain) + .await + .map_err(Into::into), + )+ + } + } + + async fn ipv6_lookup( + &self, + domain: Domain, + ) -> Result, Self::Error> { + match self { + $( + ::rama_core::combinators::$id::$param(d) => d.ipv6_lookup(domain) + .await + .map_err(Into::into), + )+ + } + } + } + }; +} + +rama_core::combinators::impl_either!(impl_dns_resolver_either_either);