diff --git a/src/dns.rs b/src/dns.rs new file mode 100644 index 00000000000..c8eb2ea9897 --- /dev/null +++ b/src/dns.rs @@ -0,0 +1,305 @@ +use core::mem; + +use anyhow::*; +use log::*; + +use esp_idf_sys::c_types::*; +use esp_idf_sys::*; + +use crate::private::cstr::{CStr, CString}; +use crate::task::{TaskConfig, TaskHandle}; +use crate::lwip; + +/// 0.0.0.0 +const IPADDR_ANY: u32 = 0x00000000; + +const DNS_PORT: c_ushort = 53; +const DNS_MAX_LEN: usize = 256; + +const OPCODE_MASK: u16 = 0x7800; +const QR_FLAG: u16 = 1 << 7; +const QD_TYPE_A: u16 = 0x0001; +const ANS_TTL_SEC: u32 = 300; + +pub struct CaptivePortalDns { + task_handle: Option, +} + +impl CaptivePortalDns { + pub fn new() -> Self { + CaptivePortalDns { task_handle: None } + } + + pub fn start(&mut self) -> Result<()> { + if self.task_handle.is_some() { + bail!("dns server is already running"); + } + + let handle = TaskConfig::default() + .priority(5) + .spawn("dns_server", dns_server_task)?; + + self.task_handle = Some(handle); + Ok(()) + } + + pub fn stop(&mut self) -> Result<()> { + if let Some(handle) = self.task_handle.take() { + handle.stop(); + Ok(()) + } else { + Err(anyhow!("dns task already stopped or was never started")) + } + } +} + +impl Drop for CaptivePortalDns { + fn drop(&mut self) { + if self.task_handle.is_some() { + self.stop().unwrap(); + } + } +} + +/// DNS Header Packet +#[repr(C, packed)] +struct DnsHeader { + id: u16, + flags: u16, + qd_count: u16, + an_count: u16, + ns_count: u16, + ar_count: u16, +} + +/// DNS Question Packet +#[repr(C)] +struct DnsQuestion { + typ: u16, + class: u16, +} + +/// DNS Answer Packet +#[repr(C, packed)] +struct DnsAnswer { + name_ptr: u16, + typ: u16, + class: u16, + ttl: u32, + addr_len: u16, + ip_addr: u32, +} + +fn parse_dns_name(raw_name: *mut u8, parsed_name: &mut [u8]) -> *mut u8 { + let mut label = raw_name; + let parsed_name_max_len = parsed_name.len(); + let mut name_itr = parsed_name.iter_mut(); + let mut name_len: usize = 0; + + loop { + let sub_name_len = unsafe { *label as c_int }; + // (len + 1) since we are adding a '.' + name_len += (sub_name_len + 1) as usize; + if name_len > parsed_name_max_len { + return core::ptr::null_mut(); + } + + // Copy the sub name that follows the the label + for i in 0..sub_name_len { + let ptr = name_itr.next().unwrap(); + *ptr = unsafe { *label.offset((i + 1) as isize) }; + } + *name_itr.next().unwrap() = '.' as u8; + label = unsafe { label.offset((sub_name_len + 1) as isize) }; + + if unsafe { *label == 0 } { + break; + } + } + + // Terminate the final string, replacing the last '.' + parsed_name[name_len - 1] = '\0' as u8; + // Return pointer to first char after the name + return unsafe { label.offset(1) }; +} + +fn parse_dns_request( + req: &mut [u8], + req_len: usize, + dns_reply: &mut [u8], + dns_reply_max_len: usize, +) -> Option { + if req_len > dns_reply_max_len { + return None; + } + + // Prepare the reply + dns_reply.fill(0); + (&mut dns_reply[0..req_len]).copy_from_slice(&req[0..req_len]); + + let header_len = mem::size_of::(); + let (header_bytes, rest) = dns_reply.split_at_mut(header_len); + + // Endianess of NW packet different from chip + let header = unsafe { + header_bytes + .as_mut_ptr() + .cast::() + .as_mut() + .unwrap() + }; + + debug!( + "DNS query with header id: 0x{:X}, flags: 0x{:X}, qd_count: {}", + ntohs(header.id), + ntohs(header.flags), + ntohs(header.qd_count) + ); + + // Not a standard query + if (header.flags & OPCODE_MASK) != 0 { + return None; + } + + // Set question response flag + header.flags |= QR_FLAG; + + let qd_count = ntohs(header.qd_count); + header.an_count = htons(qd_count); + + let reply_len = qd_count as usize * mem::size_of::() + req_len; + if reply_len > dns_reply_max_len { + return None; + } + + // Pointer to current answer and question + let (questions, answers) = rest.split_at_mut(req_len - header_len); + let cur_qd_ptr = questions.as_mut_ptr(); + let mut cur_ans_ptr = answers.as_mut_ptr(); + let mut name: [u8; 128] = [0; 128]; + + // Respond to all questions with the ESP32's IP address + for i in 0..qd_count { + debug!("answering question {}", i); + let name_end_ptr = parse_dns_name(cur_qd_ptr, &mut name); + if name_end_ptr.is_null() { + error!("failed to parse DNS question: {:?}", unsafe { + CStr::from_ptr(cur_qd_ptr as _) + }); + return None; + } + + let question = unsafe { name_end_ptr.cast::().as_mut().unwrap() }; + let qd_type = ntohs(question.typ); + let qd_class = ntohs(question.class); + + info!( + "received type: {} | class: {} | question for: {:?}", + qd_type, + qd_class, + unsafe { CStr::from_ptr(name.as_ptr() as _) } + ); + + if qd_type == QD_TYPE_A { + let answer = unsafe { cur_ans_ptr.cast::().as_mut().unwrap() }; + + let ptr_offset = unsafe { cur_qd_ptr.offset_from(dns_reply.as_ptr()) }; + answer.name_ptr = htons((0xC000 | ptr_offset) as u16); + answer.typ = htons(qd_type); + answer.class = htons(qd_class); + answer.ttl = htonl(ANS_TTL_SEC); + + let mut ip_info = esp_netif_ip_info_t::default(); + let c_if_key = CString::new("WIFI_AP_DEF").unwrap(); + unsafe { + esp_netif_get_ip_info( + esp_netif_get_handle_from_ifkey(c_if_key.as_ptr()), + &mut ip_info, + ) + }; + + info!( + "answer with PTR offset: 0x{:X} (0x{:X}) and IP 0x{:X}", + ntohs(answer.name_ptr), + ptr_offset, + ip_info.ip.addr, + ); + + answer.addr_len = htons(mem::size_of_val(&ip_info.ip.addr) as u16); + answer.ip_addr = ip_info.ip.addr; + + cur_ans_ptr = unsafe { cur_ans_ptr.offset(mem::size_of::() as isize) }; + } + } + + Some(reply_len) +} + +fn dns_server_task() -> Result<()> { + use lwip::*; + + let mut rx_buffer = [0; 128]; + + loop { + let dest_addr = sockaddr_in { + sin_addr: in_addr { + s_addr: htonl(IPADDR_ANY), + }, + sin_family: AF_INET as u8, + sin_port: htons(DNS_PORT), + ..Default::default() + }; + + let mut sock = Socket::open(AddressFamily::Ipv4, SocketType::Dgram, Protocol::Ip) + .context("creating socket")?; + info!("socket created"); + + sock.bind(&dest_addr as *const sockaddr_in as _) + .context("binding socket")?; + info!("socket bound, port {}", DNS_PORT); + + loop { + info!("waiting for data"); + + let (len, source_addr) = sock.recv_from(&mut rx_buffer).context("recv_from")?; + + // Null-terminate whatever we received + rx_buffer[len] = 0; + + let mut reply = [0; DNS_MAX_LEN]; + let reply_len = parse_dns_request(&mut rx_buffer, len, &mut reply, DNS_MAX_LEN); + + info!( + "received {} bytes from {} | DNS reply with len: {:?}", + len, source_addr, reply_len + ); + + if let Some(reply_len) = reply_len { + sock.send_to(&mut reply[0..reply_len], source_addr) + .context("send_to")?; + } else { + error!("failed to prepare a DNS reply"); + } + } + } +} + +/// host to network byte order +fn htonl(n: u32) -> u32 { + n.to_be() +} + +/// host to network byte order +fn htons(n: u16) -> u16 { + n.to_be() +} + +/// network to host byte order +fn ntohl(n: u32) -> u32 { + u32::from_be(n) +} + +/// network to host byte order +fn ntohs(n: u16) -> u16 { + u16::from_be(n) +} diff --git a/src/lib.rs b/src/lib.rs index 235f66d9f67..3f48efd71d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,10 +2,12 @@ #![cfg_attr(feature = "experimental", feature(generic_associated_types))] // for http, http::client, http::server, ota #![feature(const_btree_new)] -#[cfg(any(feature = "alloc"))] +#[cfg(feature = "alloc")] #[macro_use] extern crate alloc; +#[cfg(feature = "alloc")] +pub mod dns; #[cfg(feature = "alloc")] #[cfg(any( all(esp32, esp_idf_eth_use_esp32_emac), @@ -25,6 +27,8 @@ pub mod httpd; #[cfg(feature = "alloc")] // TODO: Ideally should not need "alloc" (also for performance reasons) pub mod log; +pub mod lwip; +pub mod misc; #[cfg(esp_idf_config_lwip_ipv4_napt)] pub mod napt; pub mod netif; @@ -43,6 +47,9 @@ pub mod nvs_storage; pub mod ota; pub mod ping; pub mod sysloop; +#[cfg(feature = "alloc")] +pub mod task; +pub mod time; #[cfg(feature = "alloc")] // TODO: Expose a subset which does not require "alloc" pub mod wifi; diff --git a/src/lwip.rs b/src/lwip.rs new file mode 100644 index 00000000000..622705546b8 --- /dev/null +++ b/src/lwip.rs @@ -0,0 +1,216 @@ +use core::cmp; +use core::fmt; +use core::mem; + +use esp_idf_sys::c_types::*; +use esp_idf_sys::*; + +use crate::private::cstr; + +pub type Result = core::result::Result; + +pub enum AddressFamily { + Ipv4, + Ipv6, +} + +impl Into for AddressFamily { + fn into(self) -> c_int { + match self { + AddressFamily::Ipv4 => AF_INET as _, + AddressFamily::Ipv6 => AF_INET6 as _, + } + } +} + +pub enum SocketType { + Stream, + Dgram, + Raw, +} + +impl Into for SocketType { + fn into(self) -> c_int { + match self { + SocketType::Stream => SOCK_STREAM as _, + SocketType::Dgram => SOCK_DGRAM as _, + SocketType::Raw => SOCK_RAW as _, + } + } +} + +pub enum Protocol { + Ip, + Icmp, + Tcp, + Udp, + Ipv6, + Icmpv6, + UdpLite, + Raw, +} + +impl Into for Protocol { + fn into(self) -> c_int { + match self { + Protocol::Ip => IPPROTO_IP as _, + Protocol::Icmp => IPPROTO_ICMP as _, + Protocol::Tcp => IPPROTO_TCP as _, + Protocol::Udp => IPPROTO_UDP as _, + Protocol::Ipv6 => IPPROTO_IPV6 as _, + Protocol::Icmpv6 => IPPROTO_ICMPV6 as _, + Protocol::UdpLite => IPPROTO_UDPLITE as _, + Protocol::Raw => IPPROTO_RAW as _, + } + } +} + +pub struct SocketAddrV4 { + inner: sockaddr_in, +} + +impl fmt::Display for SocketAddrV4 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + let mut addr_str = [0u8; INET_ADDRSTRLEN as usize]; + + unsafe { + ip4addr_ntoa_r( + &self.inner.sin_addr as *const _ as _, + addr_str.as_mut_ptr() as _, + (addr_str.len() - 1) as _, + ) + }; + + let s = cstr::from_cstr(&addr_str); + + write!(f, "{}:{}", s, self.inner.sin_port) + } +} + +pub struct SocketAddrV6 { + inner: sockaddr_in6, +} + +impl fmt::Display for SocketAddrV6 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + let mut addr_str = [0u8; INET6_ADDRSTRLEN as usize]; + + unsafe { + ip6addr_ntoa_r( + &self.inner.sin6_addr as *const _ as _, + addr_str.as_mut_ptr() as _, + (addr_str.len() - 1) as _, + ) + }; + + let s = cstr::from_cstr(&addr_str); + + write!(f, "[{}]:{}", s, self.inner.sin6_port) + } +} + +pub enum SocketAddr { + V4(SocketAddrV4), + V6(SocketAddrV6), +} + +impl SocketAddr { + pub fn inner(&self) -> (*const sockaddr, socklen_t) { + match *self { + SocketAddr::V4(ref a) => { + (a as *const _ as *const _, mem::size_of_val(a) as socklen_t) + } + SocketAddr::V6(ref a) => { + (a as *const _ as *const _, mem::size_of_val(a) as socklen_t) + } + } + } +} + +impl fmt::Display for SocketAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SocketAddr::V4(ref a) => a.fmt(f), + SocketAddr::V6(ref a) => a.fmt(f), + } + } +} + +pub struct Socket(c_int); + +impl Socket { + pub fn open(family: AddressFamily, ty: SocketType, protocol: Protocol) -> Result { + let raw = cvt(unsafe { lwip_socket(family.into(), ty.into(), protocol.into()) })?; + + Ok(Socket(raw)) + } + + pub fn bind(&mut self, addr: *const sockaddr) -> Result<()> { + lwip!(unsafe { lwip_bind(self.0, addr, mem::size_of::() as _) }) + } + + pub fn close(&mut self) -> Result<()> { + lwip!(unsafe { lwip_close(self.0) }) + } + + pub fn shutdown(&mut self) -> Result<()> { + lwip!(unsafe { lwip_shutdown(self.0, 0) }) + } + + pub fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { + let mut storage: sockaddr_storage = unsafe { mem::zeroed() }; + let mut addrlen = mem::size_of_val(&storage) as socklen_t; + + let n = cvt(unsafe { + lwip_recvfrom( + self.0, + buf.as_mut_ptr() as *mut c_void, + buf.len() as _, + 0, + &mut storage as *mut _ as *mut _, + &mut addrlen, + ) + })?; + + Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?)) + } + + pub fn send_to(&self, buf: &mut [u8], dst: SocketAddr) -> Result { + let len = cmp::min(buf.len(), ::MAX as usize) as size_t; + let (dstp, dstlen) = dst.inner(); + + let n = cvt(unsafe { + lwip_sendto(self.0, buf.as_ptr() as *const c_void, len, 0, dstp, dstlen) + })?; + Ok(n as usize) + } +} + +impl Drop for Socket { + fn drop(&mut self) { + self.shutdown().ok(); + self.close().ok(); + } +} + +fn cvt(v: c_int) -> Result { + lwip_result!(v, v) +} + +fn sockaddr_to_addr(storage: &sockaddr_storage, len: usize) -> Result { + match storage.ss_family as u32 { + AF_INET => { + assert!(len as usize >= mem::size_of::()); + Ok(SocketAddr::V4(SocketAddrV4 { + inner: unsafe { *(storage as *const _ as *const sockaddr_in) }, + })) + } + AF_INET6 => { + assert!(len as usize >= mem::size_of::()); + Ok(SocketAddr::V6(SocketAddrV6 { + inner: unsafe { *(storage as *const _ as *const sockaddr_in6) }, + })) + } + _ => Err(LwIPError::from_raw(err_enum_t_ERR_VAL)), + } +} diff --git a/src/misc.rs b/src/misc.rs new file mode 100644 index 00000000000..2ac0fb0396f --- /dev/null +++ b/src/misc.rs @@ -0,0 +1,11 @@ +use esp_idf_sys::*; + +pub fn get_default_efuse_mac() -> Result<[u8; 6], EspError> { + let mut mac = [0; 6]; + unsafe { esp!(esp_efuse_mac_get_default(mac.as_mut_ptr()))? } + Ok(mac) +} + +pub fn restart() { + unsafe { esp_restart() }; +} diff --git a/src/private/cstr.rs b/src/private/cstr.rs index 50e2509442b..ddeda1d4c70 100644 --- a/src/private/cstr.rs +++ b/src/private/cstr.rs @@ -12,6 +12,7 @@ extern crate alloc; #[cfg(feature = "alloc")] pub fn set_str(buf: &mut [u8], s: &str) { + assert!(s.len() < buf.len()); let cs = CString::new(s).unwrap(); let ss: &[u8] = cs.as_bytes_with_nul(); buf[..ss.len()].copy_from_slice(ss); diff --git a/src/private/mod.rs b/src/private/mod.rs index 26ff2fafc46..cdd5f4eaa3a 100644 --- a/src/private/mod.rs +++ b/src/private/mod.rs @@ -1,6 +1,7 @@ pub mod common; pub mod cstr; pub mod net; +pub mod wait; mod stubs; diff --git a/src/private/wait.rs b/src/private/wait.rs new file mode 100644 index 00000000000..6608842f4ce --- /dev/null +++ b/src/private/wait.rs @@ -0,0 +1,97 @@ +use core::time::Duration; +#[cfg(feature = "std")] +use std::sync::{Condvar, Mutex}; + +use embedded_svc::mutex::Mutex as _; +#[cfg(not(feature = "std"))] +use esp_idf_sys::*; + +#[cfg(not(feature = "std"))] +use super::time::micros_since_boot; + +pub struct Waiter { + #[cfg(feature = "std")] + cvar: Condvar, + #[cfg(feature = "std")] + running: Mutex, + #[cfg(not(feature = "std"))] + running: EspMutex, +} + +impl Waiter { + pub fn new() -> Self { + Waiter { + #[cfg(feature = "std")] + cvar: Condvar::new(), + #[cfg(feature = "std")] + running: Mutex::new(false), + #[cfg(not(feature = "std"))] + running: EspMutex::new(false), + } + } + + pub fn start(&self) { + self.running.with_lock(|running| *running = true); + } + + #[cfg(feature = "std")] + pub fn wait(&self) { + if !self.running.with_lock(|running| *running) { + return; + } + + let _running = self + .cvar + .wait_while(self.running.lock().unwrap(), |running| *running) + .unwrap(); + } + + #[cfg(not(feature = "std"))] + pub fn wait(&self) { + while self.running.with_lock(|running| *running) { + unsafe { vTaskDelay(500) }; + } + } + + /// return = !timeout (= success) + #[cfg(feature = "std")] + pub fn wait_timeout(&self, dur: Duration) -> bool { + if !self.running.with_lock(|running| *running) { + return true; + } + + let (_running, res) = self + .cvar + .wait_timeout_while(self.running.lock().unwrap(), dur, |running| *running) + .unwrap(); + + return !res.timed_out(); + } + + /// return = !timeout (= success) + #[cfg(not(feature = "std"))] + pub fn wait_timeout(&self, dur: Duration) { + let now = micros_since_boot(); + let end = now + dur.as_micros(); + + while self.running.with_lock(|running| *running) { + if micros_since_boot() > end { + return false; + } + unsafe { vTaskDelay(500) }; + } + + return true; + } + + #[cfg(feature = "std")] + pub fn notify(&self) { + *self.running.lock().unwrap() = false; + self.cvar.notify_all(); + } + + #[cfg(not(feature = "std"))] + pub fn notify(&self) { + self.running.with_lock(|running| *running = false); + } +} diff --git a/src/task.rs b/src/task.rs new file mode 100644 index 00000000000..6fd23a535e3 --- /dev/null +++ b/src/task.rs @@ -0,0 +1,130 @@ +use core::time::Duration; + +use log::*; + +use esp_idf_sys::c_types::*; +use esp_idf_sys::*; + +use crate::private::cstr::CString; + +#[allow(non_upper_case_globals)] +const pdPASS: c_int = 1; + +pub struct TaskHandle(TaskHandle_t); + +impl TaskHandle { + pub fn stop(&self) { + unsafe { vTaskDelete(self.0) }; + } +} + +struct TaskInternal { + name: String, + f: Box anyhow::Result<()>>, +} + +pub struct TaskConfig { + stack_size: u32, + priority: u32, +} + +impl Default for TaskConfig { + fn default() -> Self { + TaskConfig { + stack_size: DEFAULT_THREAD_STACKSIZE, + priority: DEFAULT_THREAD_PRIO, + } + } +} + +impl TaskConfig { + pub fn new(stack_size: u32, priority: u32) -> Self { + TaskConfig { + stack_size, + priority, + } + } + + pub fn stack_size(self, stack_size: u32) -> Self { + TaskConfig { stack_size, ..self } + } + + pub fn priority(self, priority: u32) -> Self { + TaskConfig { priority, ..self } + } + + pub fn spawn( + self, + name: impl AsRef, + f: F, + ) -> Result + where + F: FnOnce() -> anyhow::Result<()>, + F: Send + 'static { + let parameters = TaskInternal { + name: name.as_ref().to_string(), + f: Box::new(f), + }; + let parameters = Box::into_raw(Box::new(parameters)) as *mut _; + + info!("starting task {:?}", name.as_ref()); + + let name = CString::new(name.as_ref()).unwrap(); + let mut handle: TaskHandle_t = core::ptr::null_mut(); + let res = unsafe { + xTaskCreatePinnedToCore( + Some(esp_idf_svc_task), + name.as_ptr(), + self.stack_size, + parameters, + self.priority, + &mut handle, + tskNO_AFFINITY as i32, + ) + }; + if res != pdPASS { + return Err(EspError::from(ESP_ERR_NO_MEM as i32).unwrap().into()); + } + + Ok(TaskHandle(handle)) + } +} + +pub fn spawn( + name: impl AsRef, + f: F, +) -> Result +where + F: FnOnce() -> anyhow::Result<()>, + F: Send + 'static, +{ + TaskConfig::default().spawn(name, f) +} + +extern "C" fn esp_idf_svc_task(args: *mut c_void) { + let internal = unsafe { Box::from_raw(args as *mut TaskInternal) }; + + info!("started task {:?}", internal.name); + + match (internal.f)() { + Err(e) => { + panic!("unexpected error in task {:?}: {:?}", internal.name, e); + } + Ok(_) => {} + } + + info!("destroying task {:?}", internal.name); + + unsafe { vTaskDelete(core::ptr::null_mut() as _) }; +} + +#[allow(non_upper_case_globals)] +pub const TICK_PERIOD_MS: u32 = 1000 / configTICK_RATE_HZ; + +/// sleep tells FreeRTOS to put the current thread to sleep for at least the specified duration, +/// this is not an exact duration and can't be shorter than the rtos tick period. +pub fn sleep(duration: Duration) { + unsafe { + vTaskDelay(duration.as_millis() as u32 / TICK_PERIOD_MS); + } +} diff --git a/src/time.rs b/src/time.rs new file mode 100644 index 00000000000..8b0ba70bbe2 --- /dev/null +++ b/src/time.rs @@ -0,0 +1,7 @@ +use core::convert::TryInto; + +use esp_idf_sys::*; + +pub fn micros_since_boot() -> u64 { + unsafe { esp_timer_get_time() }.try_into().unwrap() +}