Skip to content

Commit

Permalink
optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
zyj committed Feb 6, 2025
1 parent 825b21f commit cd4b76d
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 158 deletions.
11 changes: 5 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
authors = ["zyj"]
edition = "2021"
name = "yadns"
version = "0.6.5"
version = "0.6.6"

[features]
default = ["default-doh-rustls", "default-doh3-rustls"]
default = ["default-doh-rustls", "default-doh3-rustls", "logging"]

default-doh-rustls = [
"default-dot-rustls",
Expand All @@ -28,28 +28,27 @@ default-tcp_udp = []
dns-over-https = ["hickory-resolver/dns-over-https"]
dns-over-tls = ["hickory-resolver/dns-over-tls"]
dns-over-h3 = ["hickory-resolver/dns-over-h3"]
logging = ["dep:env_logger"]

[dependencies]
async-http-proxy = {version = "1", features = ["runtime-tokio", "basic-auth"]}
async-recursion = "1"
async-trait = "0.1"
clap = {version = "4", features = ["derive"]}
crossbeam-channel = "0.5"
env_logger = {version = "0.11", optional = true}
fast-socks5 = "0.10"
futures = {version = "0.3", default-features = false, features = ["executor"]}
hickory-proto = "0.24"
hickory-resolver = {version = "0.24", default-features = false, features = ["tokio-runtime"]}
hickory-server = "0.24"
ipnet = "2"
iprange = "0.6"
once_cell = "1"
log = "0.4"
publicsuffix = "2"
regex = "1"
serde = "1"
serde_derive = "1"
slog = "2"
slog-async = "2"
slog-term = "2"
thiserror = "1"
tokio = "1"
toml = "0.8"
Expand Down
3 changes: 3 additions & 0 deletions examples/template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
# Root privilege may be required if you specify a port below 1024.
bind = "127.0.0.1:5300" # the address that ya-dns listens on

# Specify the log level
log = "info" # error warn info debug trace

# Configuration for the Resolver
[resolver_opts]
# Specify the timeout for a request. Defaults to 5 seconds
Expand Down
27 changes: 24 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@ pub enum ConfigError {
NoUpstream,
#[error("{0}:{1}")]
InvalidAddress(std::net::AddrParseError, String),
#[cfg(any(feature = "dns-over-tls", feature = "dns-over-https", feature = "dns-over-h3"))]
#[cfg(any(
feature = "dns-over-tls",
feature = "dns-over-https",
feature = "dns-over-h3"
))]
#[error("tls-host is missing")]
NoTlsHost,
}

#[derive(Debug)]
pub struct Config {
pub bind: SocketAddr,
#[cfg(feature = "logging")]
pub log_level: log::LevelFilter,
pub default_upstreams: Vec<String>,
pub resolver_opts: ResolverOpts,
pub upstreams: HashMap<String, Upstream>,
Expand All @@ -37,9 +43,11 @@ pub struct Config {
pub response_rules: Vec<ResponseRule>,
}

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct ConfigBuilder {
bind: SocketAddr,
log: Option<String>,
resolver_opts: Option<ResolverOptsConfig>,
upstreams: HashMap<String, UpstreamConfig>,
domains: Option<HashMap<String, DomainsConf>>,
Expand Down Expand Up @@ -133,6 +141,12 @@ impl ConfigBuilder {

Ok(Config {
bind: self.bind,
#[cfg(feature = "logging")]
log_level: self
.log
.as_ref()
.map(|s| log::LevelFilter::from_str(s).unwrap_or(log::LevelFilter::Info))
.unwrap_or(log::LevelFilter::Info),
default_upstreams,
resolver_opts,
upstreams,
Expand Down Expand Up @@ -199,7 +213,11 @@ pub struct UpstreamConfig {
address: Vec<String>,
network: NetworkType,
proxy: Option<String>,
#[cfg(any(feature = "dns-over-tls", feature = "dns-over-https", feature = "dns-over-h3"))]
#[cfg(any(
feature = "dns-over-tls",
feature = "dns-over-https",
feature = "dns-over-h3"
))]
#[serde(rename = "tls-host")]
tls_host: Option<String>,
#[serde(default = "UpstreamConfig::default_default")]
Expand Down Expand Up @@ -361,7 +379,10 @@ impl DomainsConf {
};
regex_set.push(dm);
} else {
let line1 = line.trim_start_matches("full:").trim_start_matches("domain:").trim_start_matches(".");
let line1 = line
.trim_start_matches("full:")
.trim_start_matches("domain:")
.trim_start_matches(".");
let dm = match line1.find(":@") {
Some(index) => line1[..index].to_string(),
None => String::from(line1),
Expand Down
15 changes: 7 additions & 8 deletions src/filter.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use crate::{
config::{RequestRule, ResponseRule, RuleAction},
handler_config::HandlerConfig,
logger::stderr,
handler_config::HandlerConfig
};
use hickory_proto::{op::LowerQuery, rr::RecordType};
use hickory_resolver::lookup::Lookup;
use slog::debug;
use log::debug;

pub fn check_response(
cfg: &HandlerConfig,
Expand Down Expand Up @@ -72,7 +71,7 @@ pub fn check_response(
.unwrap_or(RuleAction::Accept)
}

pub fn resolvers<'a>(cfg: &'a HandlerConfig, query: &LowerQuery) -> Vec<&'a str> {
pub fn resolvers(cfg: &HandlerConfig, query: &LowerQuery) -> Vec<String> {
let name = query.name().to_string();

let check_type = |rule: &RequestRule| {
Expand All @@ -88,12 +87,12 @@ pub fn resolvers<'a>(cfg: &'a HandlerConfig, query: &LowerQuery) -> Vec<&'a str>
.find(|r| check_domains(cfg, &name, &r.domains) && check_type(r));

if let Some(rule) = rule {
debug!(stderr(), "Query {} matches rule {:?}", name, rule);
rule.upstreams.iter().map(String::as_str).collect()
debug!("Query {} matches rule {:?}", name, rule);
rule.upstreams.iter().map(String::clone).collect()
} else {
debug!(stderr(), "No rule matches for {}. Use defaults.", name);
debug!("No rule matches for {}. Use defaults.", name);
// If no rule matches, use defaults
cfg.defaults.iter().map(String::as_str).collect()
cfg.defaults.iter().map(String::clone).collect()
}
}

Expand Down
187 changes: 87 additions & 100 deletions src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::time::Duration;

use crate::{config::RuleAction, filter, handler_config::HandlerConfig, logger::stderr};
use crate::{config::RuleAction, filter, handler_config::HandlerConfig};
use crossbeam_channel::bounded;
use hickory_proto::op::LowerQuery;
use hickory_resolver::{
error::{ResolveError, ResolveErrorKind},
lookup::Lookup,
Expand All @@ -11,119 +12,105 @@ use hickory_server::{
proto::op::{Header, MessageType, OpCode, ResponseCode},
server::{Request, RequestHandler, ResponseHandler, ResponseInfo},
};
use once_cell::sync::OnceCell;
use slog::debug;
use log::debug;
use tokio::{runtime::Runtime, time::timeout};

static HANDLER_CONFIG: OnceCell<HandlerConfig> = OnceCell::new();

fn handler_config() -> &'static HandlerConfig {
HANDLER_CONFIG
.get()
.expect("HandlerConfig is not initialized")
}

#[derive(Clone, Debug)]
struct RequestResult {
lookup: Option<Lookup>,
code: ResponseCode,
}

/// Handle request, returning ResponseInfo if response was successfully sent, or an error.
async fn do_handle_request(request: &Request) -> Result<RequestResult, ResolveError> {
debug!(
stderr(),
"DNS requests are forwarded to [{}].",
request.query()
);
// make sure the request is a query and the message type is a query
if request.op_code() != OpCode::Query || request.message_type() != MessageType::Query {
return Ok(RequestResult {
lookup: None,
code: ResponseCode::Refused,
});
}
do_handle_request_default(request).await
}

/// Handle requests for anything else (NXDOMAIN)
async fn do_handle_request_default(request: &Request) -> Result<RequestResult, ResolveError> {
//self.counter.fetch_add(1, Ordering::SeqCst);
let resolvers = filter::resolvers(handler_config(), request.query());
let resolvers_len = resolvers.len();
let (tx, rx) = bounded(resolvers_len);
let rt = Runtime::new().unwrap();
resolvers
.iter()
.map(|name| {
(
handler_config().resolvers.get(*name).cloned().unwrap(),
*name,
request.query().name().to_string(),
request.query().query_type(),
)
})
.for_each(|(rs, name, domain, query_type)| {
let tx1 = tx.clone();
rt.spawn(async move {
let res = timeout(Duration::from_secs(1), rs.resolve(&domain, query_type)).await;
let lookup = match res {
Ok(lookup) => lookup,
Err(_) => Err(ResolveErrorKind::Timeout.into()),
};
match lookup {
Ok(lookup) => {
let _ = tx1.try_send(Some((lookup, name, domain)));
}
Err(_) => {
let _ = tx1.try_send(None);
}
}
});
});
let mut lookup_result = None;
for _ in 0..resolvers_len {
let lookup = rx.recv().unwrap();
match lookup {
Some((lookup, name, domain)) => {
match filter::check_response(handler_config(), &domain, name, &lookup) {
RuleAction::Accept => {
debug!(stderr(), "Use result from {}", name);
lookup_result = Some(lookup);
break;
}
RuleAction::Drop => (),
}
}
None => {}
}
}
rt.shutdown_background();
drop(tx);
match lookup_result {
Some(lookup) => Ok(RequestResult {
lookup: Some(lookup),
code: ResponseCode::NoError,
}),
None => Ok(RequestResult {
lookup: None,
code: ResponseCode::NXDomain,
}),
}
}

/// DNS Request Handler
#[derive(Clone, Debug)]
pub struct Handler {
//pub counter: Arc<AtomicU64>,
config: HandlerConfig,
}
impl Handler {
/// Create handler from app config.
pub fn new(cfg: HandlerConfig) -> Self {
match HANDLER_CONFIG.set(cfg) {
_ => Handler {
// counter: Arc::new(AtomicU64::new(0)),
},
Handler { config: cfg }
}

/// Handle request, returning ResponseInfo if response was successfully sent, or an error.
async fn do_handle_request(&self, request: &Request) -> Result<RequestResult, ResolveError> {
debug!("DNS requests are forwarded to [{}].", request.query());
// make sure the request is a query and the message type is a query
if request.op_code() != OpCode::Query || request.message_type() != MessageType::Query {
return Ok(RequestResult {
lookup: None,
code: ResponseCode::Refused,
});
}
self.lookup(request.query()).await
}

/// Lookup for anything else (NXDOMAIN)
async fn lookup(&self, query: &LowerQuery) -> Result<RequestResult, ResolveError> {
//self.counter.fetch_add(1, Ordering::SeqCst);
let config = &self.config;
let resolvers = filter::resolvers(config, query);
let resolvers_len = resolvers.len();
let (tx, rx) = bounded(resolvers_len);
let rt = Runtime::new().unwrap();
resolvers
.into_iter()
.map(|name| {
(
config.resolvers.get(&name).cloned().unwrap(),
name,
query.name().to_string(),
query.query_type(),
)
})
.for_each(|(rs, name, domain, query_type)| {
let tx1 = tx.clone();
rt.spawn(async move {
let res =
timeout(Duration::from_secs(1), rs.resolve(&domain, query_type)).await;
let lookup = match res {
Ok(lookup) => lookup,
Err(_) => Err(ResolveErrorKind::Timeout.into()),
};
match lookup {
Ok(lookup) => {
let _ = tx1.try_send(Some((lookup, name, domain)));
}
Err(_) => {
let _ = tx1.try_send(None);
}
}
});
});
let mut lookup_result = None;
for _ in 0..resolvers_len {
let lookup = rx.recv().unwrap();
match lookup {
Some((lookup, name, domain)) => {
match filter::check_response(config, &domain, &name, &lookup) {
RuleAction::Accept => {
debug!("Use result from {}", name);
lookup_result = Some(lookup);
break;
}
RuleAction::Drop => (),
}
}
None => {}
}
}
rt.shutdown_background();
drop(tx);
match lookup_result {
Some(lookup) => Ok(RequestResult {
lookup: Some(lookup),
code: ResponseCode::NoError,
}),
None => Ok(RequestResult {
lookup: None,
code: ResponseCode::NXDomain,
}),
}
}
}
Expand All @@ -136,10 +123,10 @@ impl RequestHandler for Handler {
mut response: R,
) -> ResponseInfo {
// try to handle request
let result = match do_handle_request(request).await {
let result = match self.do_handle_request(request).await {
Ok(info) => info,
Err(e) => {
debug!(stderr(), "Error in RequestHandler:{:#?}", e);
debug!("Error in RequestHandler:{:#?}", e);
RequestResult {
lookup: None,
code: ResponseCode::ServFail,
Expand Down
Loading

0 comments on commit cd4b76d

Please sign in to comment.