From 48c6ade7a7eb7c2046c78c6cedc72ae34082e472 Mon Sep 17 00:00:00 2001 From: sinderella Date: Sat, 6 Aug 2022 10:41:52 +0100 Subject: [PATCH] switch to tokio and handle panics --- .gitignore | 6 + Cargo.lock | 74 +++++++++++- Cargo.toml | 3 +- src/main.rs | 335 ++++++++++++++++++++++++++++------------------------ 4 files changed, 259 insertions(+), 159 deletions(-) diff --git a/.gitignore b/.gitignore index ea8c4bf..616a294 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,7 @@ /target +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets diff --git a/Cargo.lock b/Cargo.lock index 7e7463b..28a39b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,13 +104,14 @@ checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57" [[package]] name = "dns-rebinder" -version = "0.2.0" +version = "0.2.1" dependencies = [ "clap", "env_logger", "hex", "log", "rand", + "tokio", "trust-dns-proto", ] @@ -272,6 +273,16 @@ version = "0.2.126" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" +[[package]] +name = "lock_api" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.17" @@ -327,6 +338,29 @@ version = "6.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "648001efe5d5c0102d8cea768e348da85d90af8ba91f0bea908f157951493cd4" +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + [[package]] name = "percent-encoding" version = "2.1.0" @@ -423,6 +457,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.6.0" @@ -440,6 +483,21 @@ version = "0.6.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + +[[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + [[package]] name = "slab" version = "0.4.7" @@ -545,11 +603,25 @@ dependencies = [ "mio", "num_cpus", "once_cell", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", + "tokio-macros", "winapi", ] +[[package]] +name = "tokio-macros" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9724f9a975fb987ef7a3cd9be0350edcbe130698af5b8f7a631e23d42d052484" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "trust-dns-proto" version = "0.21.2" diff --git a/Cargo.toml b/Cargo.toml index 9033994..991d805 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "dns-rebinder" authors = ["sinderella"] -version = "0.2.0" +version = "0.2.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -12,4 +12,5 @@ env_logger = "0.9.0" hex = "0.4.3" log = "0.4.17" rand = "0.8.5" +tokio = { version = "1", features = ["full"] } trust-dns-proto = "0.21.2" diff --git a/src/main.rs b/src/main.rs index e7fba02..5cac166 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,10 @@ use env_logger::Builder; use log::LevelFilter; use log::{error, info, warn}; use rand::Rng; -use std::net::{Ipv4Addr, SocketAddr, UdpSocket}; +use std::net::{Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use tokio::net::UdpSocket; + use trust_dns_proto::op::{Message, Query}; use trust_dns_proto::rr::rdata::SOA; use trust_dns_proto::rr::{Name, RData, Record, RecordType}; @@ -26,174 +29,173 @@ mod cli; type Error = Box; type Result = std::result::Result; -struct Rebinder { - domain: String, - ns_records: Option>, - ns_public_ip: Option, -} - -impl Rebinder { - pub(crate) fn new( - domain: String, - ns_records: Option>, - ns_public_ip: Option, - ) -> Rebinder { - Rebinder { - domain, - ns_records, - ns_public_ip, - } +fn process_query( + domain: Arc, + ns_records: Arc>>, + ns_public_ip: Arc>, + query: &Query, + addr: SocketAddr, + header_id: u16, +) -> Result> { + let mut records: Vec = Vec::new(); + let qname = match query.name().to_ascii().strip_suffix('.') { + Some(it) => it, + None => return Ok(records), } + .to_ascii_lowercase(); - fn process_query( - &mut self, - query: &Query, - addr: SocketAddr, - header_id: u16, - ) -> Option> { - let mut records: Vec = Vec::new(); - let qname = query - .name() - .to_ascii() - .strip_suffix('.')? - .to_ascii_lowercase(); - - // only support the root domain that it owns - if !qname.ends_with(&self.domain) { - return None; - } + // only support the root domain that it owns + if !qname.ends_with(domain.as_ref()) { + return Ok(records); + } - match query.query_type() { - RecordType::A => { - // accepting ...root.domain - let parts: Vec<&str> = qname.split('.').collect(); - - if parts[0].starts_with("ns") { - if let Some(ns_public_ip) = self.ns_public_ip { - records.push(Record::from_rdata( - query.name().clone(), - 600, - RData::A(ns_public_ip), - )); - return Some(records); - } + match query.query_type() { + RecordType::A => { + // accepting ...root.domain + let parts: Vec<&str> = qname.split('.').collect(); + + if parts[0].starts_with("ns") { + if let Some(ns_public_ip) = ns_public_ip.as_ref() { + records.push(Record::from_rdata( + query.name().clone(), + 600, + RData::A(*ns_public_ip), + )); + return Ok(records); } + } - let primary = u32::from_str_radix(parts[0], 16).unwrap(); - let secondary = u32::from_str_radix(parts[1], 16).unwrap(); + if qname.eq(domain.as_ref()) { + return Ok(records); + } - info!( - "{:?}[{:?}] - parsed targets: primary: {:#?}, secondary: {:#?}", - addr, - header_id, - Ipv4Addr::from(primary), - Ipv4Addr::from(secondary) - ); + if qname.matches('.').count() != domain.matches('.').count() + 2 { + return Ok(records); + } - if primary.eq(&secondary) { - warn!( + let loopback = u32::from_be_bytes(Ipv4Addr::new(127, 0, 0, 1).octets()); + + let primary = match u32::from_str_radix(parts[0], 16) { + Ok(decoded) => decoded, + Err(_) => return Ok(records), + }; + let secondary = match u32::from_str_radix(parts[1], 16) { + Ok(decoded) => decoded, + Err(_) => return Ok(records), + }; + + info!( + "{:?}[{:?}] - parsed targets: primary: {:#?}, secondary: {:#?}", + addr, + header_id, + Ipv4Addr::from(primary), + Ipv4Addr::from(secondary) + ); + + if primary.eq(&secondary) && primary.ne(&loopback) { + warn!( "{:?}[{:?}] - primary and secondary labels are indentical, possibly an abuse", addr, header_id ); - return None; - } + return Ok(records); + } - let mut rng = rand::thread_rng(); - let is_primary = rng.gen_range(0..2) % 2 == 0; + let mut rng = rand::thread_rng(); + let is_primary = rng.gen_range(0..2) % 2 == 0; - records.push(Record::from_rdata( - query.name().clone(), - 1, - RData::A(match is_primary { - true => Ipv4Addr::from(primary), - false => Ipv4Addr::from(secondary), - }), - )); + records.push(Record::from_rdata( + query.name().clone(), + 1, + RData::A(match is_primary { + true => Ipv4Addr::from(primary), + false => Ipv4Addr::from(secondary), + }), + )); - return Some(records); - } - RecordType::NS => match &self.ns_records { - Some(ns_records) => { - for ns_record in ns_records { - records.push(Record::from_rdata( - Name::from_ascii(qname.clone()).unwrap(), - 600, - RData::NS(ns_record.clone()), - )); - } - return Some(records); + return Ok(records); + } + RecordType::NS => match ns_records.as_ref() { + Some(ns_records) => { + for ns_record in ns_records { + records.push(Record::from_rdata( + Name::from_ascii(qname.clone()).unwrap(), + 600, + RData::NS(ns_record.clone()), + )); } - None => {} - }, - RecordType::SOA => { - let ns_record = self.ns_records.as_ref().unwrap().first()?; - - let soa = SOA::new( - ns_record.clone(), - Name::from_ascii("").unwrap(), - 1, - 86400, - 7200, - 4000000, - 600, - ); - - records.push(Record::from_rdata( - Name::from_ascii(qname).unwrap(), - 600, - RData::SOA(soa), - )); - - return Some(records); + return Ok(records); } - RecordType::AAAA => return None, - RecordType::ANY => return None, - RecordType::AXFR => return None, - RecordType::CNAME => return None, - _ => return None, + None => {} + }, + RecordType::SOA => { + let ns_record = ns_records.as_deref().unwrap().first().unwrap(); + + let soa = SOA::new( + ns_record.clone(), + Name::from_ascii("").unwrap(), + 1, + 86400, + 7200, + 4000000, + 600, + ); + + records.push(Record::from_rdata( + Name::from_ascii(qname).unwrap(), + 600, + RData::SOA(soa), + )); + + return Ok(records); } - - None + RecordType::AAAA => return Ok(records), + RecordType::ANY => return Ok(records), + RecordType::AXFR => return Ok(records), + RecordType::CNAME => return Ok(records), + _ => return Ok(records), } - pub(crate) fn handle_query(&mut self, socket: &UdpSocket) -> Result<()> { - let mut buffer = [0_u8; 512]; - let (len, addr) = socket.recv_from(&mut buffer).expect("receive failed"); - let request = Message::from_vec(&buffer[0..len]).expect("failed parse of request"); - - let mut message = Message::new(); - message.set_id(request.id()); - message.set_recursion_desired(request.recursion_desired()); - message.set_recursion_available(false); - - // unlikely, see https://stackoverflow.com/a/4083071 - if request.query_count() != 1 { - let bytes = message.to_vec().unwrap(); - socket.send_to(&bytes, addr).expect("send failed"); - return Ok(()); - } - - if let Some(query) = request.queries().first() { - info!("{:?}[{:?}] - {:?}", addr, request.id(), query); - message.add_query(query.clone()); - if let Some(records) = self.process_query(query, addr, request.id()) { - info!("{:?}[{:?}] - {:?}", addr, request.id(), records); - message.add_answers(records); - } - } else { - let bytes = message.to_vec().unwrap(); - socket.send_to(&bytes, addr).expect("send failed"); - return Ok(()); - } + Ok(records) +} +pub(crate) async fn handle_connection( + socket: Arc, + domain: Arc, + ns_records: Arc>>, + ns_public_ip: Arc>, +) -> Result<()> { + let mut buffer = [0_u8; 512]; + let (len, addr) = socket.recv_from(&mut buffer).await.expect("receive failed"); + let request = Message::from_vec(&buffer[0..len]).expect("failed parse of request"); + + let mut message = Message::new(); + message.set_id(request.id()); + message.set_recursion_desired(request.recursion_desired()); + message.set_recursion_available(false); + + // unlikely, see https://stackoverflow.com/a/4083071 + if request.query_count() != 1 { let bytes = message.to_vec().unwrap(); - socket.send_to(&bytes, addr).expect("send failed"); + socket.send_to(&bytes, addr).await.expect("send failed"); + return Ok(()); + } - Ok(()) + if let Some(query) = request.queries().first() { + info!("{:?}[{:?}] - {:?}", addr, request.id(), query); + message.add_query(query.clone()); + let records = process_query(domain, ns_records, ns_public_ip, query, addr, request.id())?; + info!("{:?}[{:?}] - {:?}", addr, request.id(), records); + message.add_answers(records); } + + let bytes = message.to_vec().unwrap(); + socket.send_to(&bytes, addr).await.expect("send failed"); + + Ok(()) } -fn main() -> Result<()> { +#[tokio::main()] +async fn main() -> Result<()> { Builder::new().filter_level(LevelFilter::Debug).init(); let cli = Cli::parse(); @@ -223,23 +225,42 @@ fn main() -> Result<()> { domain, port, network_interface, ns_records ); - let socket = UdpSocket::bind((network_interface, port))?; + let socket = UdpSocket::bind((network_interface, port)) + .await + .expect("couldn't bind to address"); info!("started listening on port: {:?}", port); - let mut rebinder = Rebinder::new(domain, ns_records, ns_public_ip); - - let handler = std::thread::Builder::new() - .name("rebinder:server".to_string()) - .spawn(move || loop { - match rebinder.handle_query(&socket) { + let s = Arc::new(socket); + let d = Arc::new(domain); + let nsr = Arc::new(ns_records); + let npip = Arc::new(ns_public_ip); + + loop { + let sock_param = Arc::clone(&s); + let domain_param = Arc::clone(&d); + let ns_records_param = Arc::clone(&nsr); + let ns_public_ip_param = Arc::clone(&npip); + + let handler = tokio::spawn(async move { + match handle_connection( + sock_param, + domain_param, + ns_records_param, + ns_public_ip_param, + ) + .await + { Ok(_) => {} - Err(e) => error!("An error occurred: {}", e), + Err(e) => error!("{}", e), } + }); - std::thread::yield_now(); - })?; - - handler.join().unwrap(); + match handler.await { + Ok(_) => {} + Err(e) => error!("{}", e), + } + } + #[allow(unreachable_code)] Ok(()) }