From f21c1a34089e6a1545e00df93ac887895c04c986 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Thu, 30 Nov 2023 10:28:43 +0100 Subject: [PATCH] Split up edge-net into smaller crates --- Cargo.toml | 88 +- edge-captive/Cargo.toml | 14 + edge-captive/src/lib.rs | 110 ++ edge-captive/src/server.rs | 150 ++ edge-dhcp/Cargo.toml | 26 + src/asynch/dhcp.rs => edge-dhcp/src/asynch.rs | 6 +- src/dhcp.rs => edge-dhcp/src/lib.rs | 10 + edge-http/Cargo.toml | 24 + src/asynch/http.rs => edge-http/src/asynch.rs | 0 {src/asynch/http => edge-http/src}/client.rs | 2 +- edge-http/src/lib.rs | 1523 +++++++++++++++++ {src/asynch/http => edge-http/src}/server.rs | 4 +- edge-mdns/Cargo.toml | 12 + src/mdns.rs => edge-mdns/src/lib.rs | 2 + edge-mqtt/Cargo.toml | 12 + src/asynch/rumqttc.rs => edge-mqtt/src/lib.rs | 2 + edge-tcp/Cargo.toml | 19 + edge-tcp/README.md | 1 + src/asynch/tcp.rs => edge-tcp/src/lib.rs | 9 +- edge-ws/Cargo.toml | 14 + src/asynch/ws.rs => edge-ws/src/lib.rs | 71 +- src/asynch.rs | 84 +- src/asynch/io.rs | 102 -- src/captive.rs | 259 --- src/lib.rs | 28 +- src/{std_mutex.rs => std.rs} | 3 + src/{asynch/stdnal.rs => std/nal.rs} | 8 +- src/utils.rs | 4 - 28 files changed, 2126 insertions(+), 461 deletions(-) create mode 100644 edge-captive/Cargo.toml create mode 100644 edge-captive/src/lib.rs create mode 100644 edge-captive/src/server.rs create mode 100644 edge-dhcp/Cargo.toml rename src/asynch/dhcp.rs => edge-dhcp/src/asynch.rs (99%) rename src/dhcp.rs => edge-dhcp/src/lib.rs (99%) create mode 100644 edge-http/Cargo.toml rename src/asynch/http.rs => edge-http/src/asynch.rs (100%) rename {src/asynch/http => edge-http/src}/client.rs (99%) create mode 100644 edge-http/src/lib.rs rename {src/asynch/http => edge-http/src}/server.rs (99%) create mode 100644 edge-mdns/Cargo.toml rename src/mdns.rs => edge-mdns/src/lib.rs (99%) create mode 100644 edge-mqtt/Cargo.toml rename src/asynch/rumqttc.rs => edge-mqtt/src/lib.rs (99%) create mode 100644 edge-tcp/Cargo.toml create mode 100644 edge-tcp/README.md rename src/asynch/tcp.rs => edge-tcp/src/lib.rs (95%) create mode 100644 edge-ws/Cargo.toml rename src/asynch/ws.rs => edge-ws/src/lib.rs (89%) delete mode 100644 src/asynch/io.rs delete mode 100644 src/captive.rs rename src/{std_mutex.rs => std.rs} (89%) rename src/{asynch/stdnal.rs => std/nal.rs} (98%) delete mode 100644 src/utils.rs diff --git a/Cargo.toml b/Cargo.toml index bfa6332..62efed5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,38 +12,50 @@ readme = "README.md" rust-version = "1.71" [features] -default = ["std"] +std = [ + "alloc", + "embassy-sync/std", + "embedded-svc?/std", + "edge-http?/std", + "edge-captive?/std", + "dep:edge-tcp", -std = ["alloc", "embedded-io/std", "embassy-sync/std", "embedded-svc?/std", "async-io", "libc", "futures-lite/std", "httparse/std", "domain?/std"] -alloc = ["embedded-io/alloc", "embedded-svc?/alloc", "embedded-io-async?/alloc", ] -nightly = ["embedded-io-async", "embassy-sync/nightly", "embedded-svc?/nightly"] -embassy-util = [] # Just for backward compat. To be removed in future + "futures-lite/std", + "dep:async-io", + "dep:embedded-io-async", + "dep:embedded-nal-async", + "dep:libc" +] +alloc = ["edge-http?/alloc", "embedded-svc?/alloc"] +nightly = ["dep:edge-http", "embedded-svc?/nightly"] [dependencies] -embedded-io = { version = "0.6", default-features = false } -embedded-io-async = { version = "0.6", default-features = false, optional = true } -heapless = { version = "0.7", default-features = false } -httparse = { version = "1.7", default-features = false } -num_enum = { version = "0.7", default-features = false } -embassy-futures = "0.1" -embassy-time = "0.1" -embassy-sync = "0.3" -no-std-net = { version = "0.6", default-features = false } -rand_core = "0.6" -log = { version = "0.4", default-features = false } -base64 = { version = "0.13", default-features = false } -sha1_smol = { version = "1", default-features = false } -embedded-nal-async = "0.6" -domain = { version = "0.7", default-features = false, optional = true } +embedded-svc = { version = "0.26", default-features = false, optional = true, features = ["embedded-io-async"] } + futures-lite = { version = "1", default-features = false, optional = true } async-io = { version = "2", default-features = false, optional = true } +embedded-io-async = { workspace = true, optional = true } +embedded-nal-async = { workspace = true, optional = true } libc = { version = "0.2", default-features = false, optional = true } -embedded-svc = { version = "0.26", default-features = false, optional = true, features = ["embedded-io-async"] } -rumqttc = { version = "0.19", optional = true } + +embassy-sync.workspace = true +log.workspace = true +heapless.workspace = true +no-std-net.workspace = true + +edge-captive = { workspace = true, optional = true } +edge-dhcp = { workspace = true, optional = true } +edge-http = { workspace = true, optional = true } +edge-mdns = { workspace = true, optional = true } +edge-mqtt = { workspace = true, optional = true } +edge-ws = { workspace = true, optional = true } + +# edge-tcp is an exception that only contains trait definitions +edge-tcp = { workspace = true, optional = true } [dev-dependencies] anyhow = "1" -simple_logger="2.2" +simple_logger = "2.2" [[example]] name = "captive_portal" @@ -60,3 +72,33 @@ required-features = ["std", "nightly"] [[example]] name = "http_server" required-features = ["std", "nightly"] + +[workspace] +members = [ + ".", + "edge-captive", + "edge-dhcp", + "edge-http", + "edge-mdns", + "edge-mqtt", + "edge-tcp", + "edge-ws" +] + +[workspace.dependencies] +embassy-futures = "0.1" +embassy-sync = "0.3" +embedded-io = { version = "0.6", default-features = false } +embedded-io-async = { version = "0.6", default-features = false } +embedded-nal-async = "0.6" +log = { version = "0.4", default-features = false } +heapless = { version = "0.7", default-features = false } +no-std-net = { version = "0.6", default-features = false } + +edge-captive = { version = "0.1.0", path = "edge-captive" } +edge-dhcp = { version = "0.1.0", path = "edge-dhcp" } +edge-http = { version = "0.1.0", path = "edge-http" } +edge-mdns = { version = "0.1.0", path = "edge-mdns" } +edge-mqtt = { version = "0.1.0", path = "edge-mqtt" } +edge-tcp = { version = "0.1.0", path = "edge-tcp" } +edge-ws = { version = "0.1.0", path = "edge-ws" } diff --git a/edge-captive/Cargo.toml b/edge-captive/Cargo.toml new file mode 100644 index 0000000..c613d6b --- /dev/null +++ b/edge-captive/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "edge-captive" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +std = ["domain/std"] + +[dependencies] +log.workspace = true + +domain = { version = "0.7", default-features = false } diff --git a/edge-captive/src/lib.rs b/edge-captive/src/lib.rs new file mode 100644 index 0000000..2127be0 --- /dev/null +++ b/edge-captive/src/lib.rs @@ -0,0 +1,110 @@ +#![cfg_attr(not(feature = "std"), no_std)] + +use core::fmt; +use core::time::Duration; + +use log::debug; + +use domain::{ + base::{ + iana::{Class, Opcode, Rcode}, + octets::*, + Record, Rtype, + }, + rdata::A, +}; + +#[cfg(feature = "std")] +mod server; + +#[cfg(feature = "std")] +pub use server::*; + +#[derive(Debug)] +pub struct InnerError(T); + +#[derive(Debug)] +pub enum DnsError { + ShortBuf(InnerError), + ParseError(InnerError), +} + +impl fmt::Display for DnsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DnsError::ShortBuf(e) => e.0.fmt(f), + DnsError::ParseError(e) => e.0.fmt(f), + } + } +} + +impl From for DnsError { + fn from(e: ShortBuf) -> Self { + Self::ShortBuf(InnerError(e)) + } +} + +impl From for DnsError { + fn from(e: ParseError) -> Self { + Self::ParseError(InnerError(e)) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for DnsError {} + +pub fn process_dns_request( + request: impl AsRef<[u8]>, + ip: &[u8; 4], + ttl: Duration, +) -> Result, DnsError> { + let request = request.as_ref(); + let response = Octets512::new(); + + let message = domain::base::Message::from_octets(request)?; + debug!("Processing message with header: {:?}", message.header()); + + let mut responseb = domain::base::MessageBuilder::from_target(response)?; + + let response = if matches!(message.header().opcode(), Opcode::Query) { + debug!("Message is of type Query, processing all questions"); + + let mut answerb = responseb.start_answer(&message, Rcode::NoError)?; + + for question in message.question() { + let question = question?; + + if matches!(question.qtype(), Rtype::A) { + debug!( + "Question {:?} is of type A, answering with IP {:?}, TTL {:?}", + question, ip, ttl + ); + + let record = Record::new( + question.qname(), + Class::In, + ttl.as_secs() as u32, + A::from_octets(ip[0], ip[1], ip[2], ip[3]), + ); + debug!("Answering question {:?} with {:?}", question, record); + answerb.push(record)?; + } else { + debug!("Question {:?} is not of type A, not answering", question); + } + } + + answerb.finish() + } else { + debug!("Message is not of type Query, replying with NotImp"); + + let headerb = responseb.header_mut(); + + headerb.set_id(message.header().id()); + headerb.set_opcode(message.header().opcode()); + headerb.set_rd(message.header().rd()); + headerb.set_rcode(domain::base::iana::Rcode::NotImp); + + responseb.finish() + }; + Ok(response) +} diff --git a/edge-captive/src/server.rs b/edge-captive/src/server.rs new file mode 100644 index 0000000..c56b274 --- /dev/null +++ b/edge-captive/src/server.rs @@ -0,0 +1,150 @@ +use std::{ + io, mem, + net::{Ipv4Addr, SocketAddrV4, UdpSocket}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + thread::{self, JoinHandle}, + time::Duration, +}; + +use log::*; + +#[derive(Clone, Debug)] +pub struct DnsConf { + pub bind_ip: Ipv4Addr, + pub bind_port: u16, + pub ip: Ipv4Addr, + pub ttl: Duration, +} + +impl DnsConf { + pub fn new(ip: Ipv4Addr) -> Self { + Self { + bind_ip: Ipv4Addr::new(0, 0, 0, 0), + bind_port: 53, + ip, + ttl: Duration::from_secs(60), + } + } +} + +#[derive(Debug)] +pub enum Status { + Stopped, + Started, + Error(io::Error), +} + +pub struct DnsServer { + conf: DnsConf, + status: Status, + running: Arc, + handle: Option>>, +} + +impl DnsServer { + pub fn new(conf: DnsConf) -> Self { + Self { + conf, + status: Status::Stopped, + running: Arc::new(AtomicBool::new(false)), + handle: None, + } + } + + pub fn get_status(&mut self) -> &Status { + self.cleanup(); + &self.status + } + + pub fn start(&mut self) -> Result<(), io::Error> { + if matches!(self.get_status(), Status::Started) { + return Ok(()); + } + let socket_address = SocketAddrV4::new(self.conf.bind_ip, self.conf.bind_port); + let running = self.running.clone(); + let ip = self.conf.ip; + let ttl = self.conf.ttl; + + self.running.store(true, Ordering::Relaxed); + self.handle = Some( + thread::Builder::new() + // default stack size is not enough + // 9000 was found via trial and error + .stack_size(9000) + .spawn(move || { + // Socket is not movable across thread bounds + // Otherwise we run into an assertion error here: https://github.com/espressif/esp-idf/blob/master/components/lwip/port/esp32/freertos/sys_arch.c#L103 + let socket = UdpSocket::bind(socket_address)?; + socket.set_read_timeout(Some(Duration::from_secs(1)))?; + let result = Self::run(&running, ip, ttl, socket); + + running.store(false, Ordering::Relaxed); + + result + }) + .unwrap(), + ); + + Ok(()) + } + + pub fn stop(&mut self) -> Result<(), io::Error> { + if matches!(self.get_status(), Status::Stopped) { + return Ok(()); + } + + self.running.store(false, Ordering::Relaxed); + self.cleanup(); + + let mut status = Status::Stopped; + mem::swap(&mut self.status, &mut status); + + match status { + Status::Error(e) => Err(e), + _ => Ok(()), + } + } + + fn cleanup(&mut self) { + if !self.running.load(Ordering::Relaxed) && self.handle.is_some() { + self.status = match mem::take(&mut self.handle).unwrap().join().unwrap() { + Ok(_) => Status::Stopped, + Err(e) => Status::Error(e), + }; + } + } + + fn run( + running: &AtomicBool, + ip: Ipv4Addr, + ttl: Duration, + socket: UdpSocket, + ) -> Result<(), io::Error> { + while running.load(Ordering::Relaxed) { + let mut request_arr = [0_u8; 512]; + debug!("Waiting for data"); + let (request_len, source_addr) = match socket.recv_from(&mut request_arr) { + Ok(value) => value, + Err(err) => match err.kind() { + std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut => continue, + _ => return Err(err), + }, + }; + + let request = &request_arr[..request_len]; + + debug!("Received {} bytes from {}", request.len(), source_addr); + let response = super::process_dns_request(request, &ip.octets(), ttl) + .map_err(|_| io::ErrorKind::Other)?; + + socket.send_to(response.as_ref(), source_addr)?; + + debug!("Sent {} bytes to {}", response.as_ref().len(), source_addr); + } + + Ok(()) + } +} diff --git a/edge-dhcp/Cargo.toml b/edge-dhcp/Cargo.toml new file mode 100644 index 0000000..5904e17 --- /dev/null +++ b/edge-dhcp/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "edge-dhcp" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +default = [] +nightly = [ + "dep:embassy-futures", + "dep:edge-tcp", + "dep:embedded-nal-async", +] + +[dependencies] +no-std-net.workspace = true +heapless = { workspace = true } +log.workspace = true +embassy-time = { version = "0.1", default-features = false } +rand_core = "0.6" + +edge-tcp = { workspace = true, optional = true } +embassy-futures = { workspace = true, optional = true } +embedded-nal-async = { workspace = true, optional = true } +num_enum = { version = "0.7", default-features = false } diff --git a/src/asynch/dhcp.rs b/edge-dhcp/src/asynch.rs similarity index 99% rename from src/asynch/dhcp.rs rename to edge-dhcp/src/asynch.rs index 796cc66..a88a784 100644 --- a/src/asynch/dhcp.rs +++ b/edge-dhcp/src/asynch.rs @@ -2,9 +2,9 @@ use core::fmt::Debug; use embedded_nal_async::{SocketAddr, SocketAddrV4, UdpStack, UnconnectedUdp}; -use crate::dhcp; +use crate as dhcp; -use super::tcp::{RawSocket, RawStack}; +use edge_tcp::{RawSocket, RawStack}; #[derive(Debug)] pub enum Error { @@ -231,7 +231,7 @@ pub mod client { pub use super::*; - pub use crate::dhcp::Settings; + pub use crate::Settings; #[derive(Clone, Debug)] pub struct Configuration { diff --git a/src/dhcp.rs b/edge-dhcp/src/lib.rs similarity index 99% rename from src/dhcp.rs rename to edge-dhcp/src/lib.rs index 9d628c2..715a92b 100644 --- a/src/dhcp.rs +++ b/edge-dhcp/src/lib.rs @@ -1,3 +1,10 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![allow(stable_features)] +#![allow(unknown_lints)] +#![feature(async_fn_in_trait)] +#![allow(async_fn_in_trait)] +#![feature(impl_trait_projections)] + /// This code is a `no_std` and no-alloc modification of https://github.com/krolaw/dhcp4r use core::str::Utf8Error; @@ -5,6 +12,9 @@ use no_std_net::Ipv4Addr; use num_enum::TryFromPrimitive; +#[cfg(feature = "nightly")] +pub mod asynch; + use self::raw_ip::{Ipv4PacketHeader, UdpPacketHeader}; #[derive(Debug)] diff --git a/edge-http/Cargo.toml b/edge-http/Cargo.toml new file mode 100644 index 0000000..617aa3d --- /dev/null +++ b/edge-http/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "edge-http" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +std = ["httparse/std", "embedded-io/std", "embassy-sync/std"] +alloc = ["embedded-io/alloc", "embedded-io-async/alloc"] + +[dependencies] +embassy-sync = { workspace = true, features = ["nightly"] } +embedded-io.workspace = true +embedded-io-async.workspace = true +embedded-nal-async.workspace = true +heapless.workspace = true +log.workspace = true +no-std-net.workspace = true + +embassy-futures.workspace = true +httparse = { version = "1.7", default-features = false } + +edge-tcp = { version = "0.1.0", path = "../edge-tcp" } diff --git a/src/asynch/http.rs b/edge-http/src/asynch.rs similarity index 100% rename from src/asynch/http.rs rename to edge-http/src/asynch.rs diff --git a/src/asynch/http/client.rs b/edge-http/src/client.rs similarity index 99% rename from src/asynch/http/client.rs rename to edge-http/src/client.rs index d039c92..85dd606 100644 --- a/src/asynch/http/client.rs +++ b/edge-http/src/client.rs @@ -4,7 +4,7 @@ use embedded_io::ErrorType; use embedded_io_async::{Read, Write}; use no_std_net::SocketAddr; -use crate::asynch::http::{ +use crate::{ send_headers, send_headers_end, send_request, Body, BodyType, Error, ResponseHeaders, SendBody, }; use embedded_nal_async::TcpConnect; diff --git a/edge-http/src/lib.rs b/edge-http/src/lib.rs new file mode 100644 index 0000000..33db836 --- /dev/null +++ b/edge-http/src/lib.rs @@ -0,0 +1,1523 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![allow(stable_features)] +#![allow(unknown_lints)] +#![feature(async_fn_in_trait)] +#![allow(async_fn_in_trait)] +#![feature(impl_trait_projections)] +#![feature(impl_trait_in_assoc_type)] + +use core::cmp::min; +use core::fmt::{Display, Write as _}; +use core::str; + +use embedded_io::ErrorType; +use embedded_io_async::{Read, Write}; + +use httparse::{Header, Status, EMPTY_HEADER}; + +use log::trace; + +#[allow(unused_imports)] +#[cfg(feature = "embedded-svc")] +pub use embedded_svc_compat::*; + +#[cfg(feature = "nightly")] +pub mod asynch; + +pub mod client; +pub mod server; + +/// An error in parsing the headers or the body. +#[derive(Debug)] +pub enum Error { + InvalidHeaders, + InvalidBody, + TooManyHeaders, + TooLongHeaders, + TooLongBody, + IncompleteHeaders, + IncompleteBody, + InvalidState, + Io(E), +} + +impl From for Error { + fn from(e: httparse::Error) -> Self { + match e { + httparse::Error::HeaderName => Self::InvalidHeaders, + httparse::Error::HeaderValue => Self::InvalidHeaders, + httparse::Error::NewLine => Self::InvalidHeaders, + httparse::Error::Status => Self::InvalidHeaders, + httparse::Error::Token => Self::InvalidHeaders, + httparse::Error::TooManyHeaders => Self::TooManyHeaders, + httparse::Error::Version => Self::InvalidHeaders, + } + } +} + +impl embedded_io::Error for Error +where + E: embedded_io::Error, +{ + fn kind(&self) -> embedded_io::ErrorKind { + match self { + Self::Io(e) => e.kind(), + _ => embedded_io::ErrorKind::Other, + } + } +} + +impl Display for Error +where + E: Display, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::InvalidHeaders => write!(f, "Invalid HTTP headers or status line"), + Self::InvalidBody => write!(f, "Invalid HTTP body"), + Self::TooManyHeaders => write!(f, "Too many HTTP headers"), + Self::TooLongHeaders => write!(f, "HTTP headers section is too long"), + Self::TooLongBody => write!(f, "HTTP body is too long"), + Self::IncompleteHeaders => write!(f, "HTTP headers section is incomplete"), + Self::IncompleteBody => write!(f, "HTTP body is incomplete"), + Self::InvalidState => write!(f, "Connection is not in requested state"), + Self::Io(e) => write!(f, "{e}"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error where E: std::error::Error {} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "std", derive(Hash))] +pub enum Method { + Delete, + Get, + Head, + Post, + Put, + Connect, + Options, + Trace, + Copy, + Lock, + MkCol, + Move, + Propfind, + Proppatch, + Search, + Unlock, + Bind, + Rebind, + Unbind, + Acl, + Report, + MkActivity, + Checkout, + Merge, + MSearch, + Notify, + Subscribe, + Unsubscribe, + Patch, + Purge, + MkCalendar, + Link, + Unlink, +} + +impl Method { + pub fn new(method: &str) -> Option { + if method.eq_ignore_ascii_case("Delete") { + Some(Self::Delete) + } else if method.eq_ignore_ascii_case("Get") { + Some(Self::Get) + } else if method.eq_ignore_ascii_case("Head") { + Some(Self::Head) + } else if method.eq_ignore_ascii_case("Post") { + Some(Self::Post) + } else if method.eq_ignore_ascii_case("Put") { + Some(Self::Put) + } else if method.eq_ignore_ascii_case("Connect") { + Some(Self::Connect) + } else if method.eq_ignore_ascii_case("Options") { + Some(Self::Options) + } else if method.eq_ignore_ascii_case("Trace") { + Some(Self::Trace) + } else if method.eq_ignore_ascii_case("Copy") { + Some(Self::Copy) + } else if method.eq_ignore_ascii_case("Lock") { + Some(Self::Lock) + } else if method.eq_ignore_ascii_case("MkCol") { + Some(Self::MkCol) + } else if method.eq_ignore_ascii_case("Move") { + Some(Self::Move) + } else if method.eq_ignore_ascii_case("Propfind") { + Some(Self::Propfind) + } else if method.eq_ignore_ascii_case("Proppatch") { + Some(Self::Proppatch) + } else if method.eq_ignore_ascii_case("Search") { + Some(Self::Search) + } else if method.eq_ignore_ascii_case("Unlock") { + Some(Self::Unlock) + } else if method.eq_ignore_ascii_case("Bind") { + Some(Self::Bind) + } else if method.eq_ignore_ascii_case("Rebind") { + Some(Self::Rebind) + } else if method.eq_ignore_ascii_case("Unbind") { + Some(Self::Unbind) + } else if method.eq_ignore_ascii_case("Acl") { + Some(Self::Acl) + } else if method.eq_ignore_ascii_case("Report") { + Some(Self::Report) + } else if method.eq_ignore_ascii_case("MkActivity") { + Some(Self::MkActivity) + } else if method.eq_ignore_ascii_case("Checkout") { + Some(Self::Checkout) + } else if method.eq_ignore_ascii_case("Merge") { + Some(Self::Merge) + } else if method.eq_ignore_ascii_case("MSearch") { + Some(Self::MSearch) + } else if method.eq_ignore_ascii_case("Notify") { + Some(Self::Notify) + } else if method.eq_ignore_ascii_case("Subscribe") { + Some(Self::Subscribe) + } else if method.eq_ignore_ascii_case("Unsubscribe") { + Some(Self::Unsubscribe) + } else if method.eq_ignore_ascii_case("Patch") { + Some(Self::Patch) + } else if method.eq_ignore_ascii_case("Purge") { + Some(Self::Purge) + } else if method.eq_ignore_ascii_case("MkCalendar") { + Some(Self::MkCalendar) + } else if method.eq_ignore_ascii_case("Link") { + Some(Self::Link) + } else if method.eq_ignore_ascii_case("Unlink") { + Some(Self::Unlink) + } else { + None + } + } + + fn as_str(&self) -> &'static str { + match self { + Self::Delete => "DELETE", + Self::Get => "GET", + Self::Head => "HEAD", + Self::Post => "POST", + Self::Put => "PUT", + Self::Connect => "CONNECT", + Self::Options => "OPTIONS", + Self::Trace => "TRACE", + Self::Copy => "COPY", + Self::Lock => "LOCK", + Self::MkCol => "MKCOL", + Self::Move => "MOVE", + Self::Propfind => "PROPFIND", + Self::Proppatch => "PROPPATCH", + Self::Search => "SEARCH", + Self::Unlock => "UNLOCK", + Self::Bind => "BIND", + Self::Rebind => "REBIND", + Self::Unbind => "UNBIND", + Self::Acl => "ACL", + Self::Report => "REPORT", + Self::MkActivity => "MKACTIVITY", + Self::Checkout => "CHECKOUT", + Self::Merge => "MERGE", + Self::MSearch => "MSEARCH", + Self::Notify => "NOTIFY", + Self::Subscribe => "SUBSCRIBE", + Self::Unsubscribe => "UNSUBSCRIBE", + Self::Patch => "PATCH", + Self::Purge => "PURGE", + Self::MkCalendar => "MKCALENDAR", + Self::Link => "LINK", + Self::Unlink => "UNLINK", + } + } +} + +impl Display for Method { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +pub async fn send_request( + method: Option, + path: Option<&str>, + output: W, +) -> Result<(), Error> +where + W: Write, +{ + send_status_line(true, method.map(|method| method.as_str()), path, output).await +} + +pub async fn send_status( + status: Option, + reason: Option<&str>, + output: W, +) -> Result<(), Error> +where + W: Write, +{ + let status_str = status.map(heapless::String::<5>::from); + + send_status_line( + false, + status_str.as_ref().map(|status| status.as_str()), + reason, + output, + ) + .await +} + +pub async fn send_headers<'a, H, W>(headers: H, output: W) -> Result> +where + W: Write, + H: IntoIterator, +{ + send_raw_headers( + headers + .into_iter() + .map(|(name, value)| (*name, value.as_bytes())), + output, + ) + .await +} + +pub async fn send_raw_headers<'a, H, W>( + headers: H, + mut output: W, +) -> Result> +where + W: Write, + H: IntoIterator, +{ + let mut body = BodyType::Unknown; + + for (name, value) in headers.into_iter() { + if body == BodyType::Unknown { + body = BodyType::from_header(name, unsafe { str::from_utf8_unchecked(value) }); + } + + output.write_all(name.as_bytes()).await.map_err(Error::Io)?; + output.write_all(b": ").await.map_err(Error::Io)?; + output.write_all(value).await.map_err(Error::Io)?; + output.write_all(b"\r\n").await.map_err(Error::Io)?; + } + + Ok(body) +} + +pub async fn send_headers_end(mut output: W) -> Result<(), Error> +where + W: Write, +{ + output.write_all(b"\r\n").await.map_err(Error::Io) +} + +#[derive(Debug)] +pub struct Headers<'b, const N: usize = 64>([httparse::Header<'b>; N]); + +impl<'b, const N: usize> Headers<'b, N> { + pub const fn new() -> Self { + Self([httparse::EMPTY_HEADER; N]) + } + + pub fn content_len(&self) -> Option { + self.get("Content-Length") + .map(|content_len_str| content_len_str.parse::().unwrap()) + } + + pub fn content_type(&self) -> Option<&str> { + self.get("Content-Type") + } + + pub fn content_encoding(&self) -> Option<&str> { + self.get("Content-Encoding") + } + + pub fn transfer_encoding(&self) -> Option<&str> { + self.get("Transfer-Encoding") + } + + pub fn host(&self) -> Option<&str> { + self.get("Host") + } + + pub fn connection(&self) -> Option<&str> { + self.get("Connection") + } + + pub fn cache_control(&self) -> Option<&str> { + self.get("Cache-Control") + } + + pub fn upgrade(&self) -> Option<&str> { + self.get("Upgrade") + } + + pub fn iter(&self) -> impl Iterator { + self.iter_raw() + .map(|(name, value)| (name, unsafe { str::from_utf8_unchecked(value) })) + } + + pub fn iter_raw(&self) -> impl Iterator { + self.0 + .iter() + .filter(|header| !header.name.is_empty()) + .map(|header| (header.name, header.value)) + } + + pub fn get(&self, name: &str) -> Option<&str> { + self.iter() + .find(|(hname, _)| name.eq_ignore_ascii_case(hname)) + .map(|(_, value)| value) + } + + pub fn get_raw(&self, name: &str) -> Option<&[u8]> { + self.iter_raw() + .find(|(hname, _)| name.eq_ignore_ascii_case(hname)) + .map(|(_, value)| value) + } + + pub fn set(&mut self, name: &'b str, value: &'b str) -> &mut Self { + self.set_raw(name, value.as_bytes()) + } + + pub fn set_raw(&mut self, name: &'b str, value: &'b [u8]) -> &mut Self { + if !name.is_empty() { + for header in &mut self.0 { + if header.name.is_empty() || header.name.eq_ignore_ascii_case(name) { + *header = Header { name, value }; + return self; + } + } + } + + panic!("No space left"); + } + + pub fn remove(&mut self, name: &str) -> &mut Self { + let index = self + .0 + .iter() + .enumerate() + .find(|(_, header)| header.name.eq_ignore_ascii_case(name)); + + if let Some((mut index, _)) = index { + while index < self.0.len() - 1 { + self.0[index] = self.0[index + 1]; + + index += 1; + } + + self.0[index] = EMPTY_HEADER; + } + + self + } + + pub fn set_content_len( + &mut self, + content_len: u64, + buf: &'b mut heapless::String<20>, + ) -> &mut Self { + *buf = heapless::String::<20>::from(content_len); + + self.set("Content-Length", buf.as_str()) + } + + pub fn set_content_type(&mut self, content_type: &'b str) -> &mut Self { + self.set("Content-Type", content_type) + } + + pub fn set_content_encoding(&mut self, content_encoding: &'b str) -> &mut Self { + self.set("Content-Encoding", content_encoding) + } + + pub fn set_transfer_encoding(&mut self, transfer_encoding: &'b str) -> &mut Self { + self.set("Transfer-Encoding", transfer_encoding) + } + + pub fn set_transfer_encoding_chunked(&mut self) -> &mut Self { + self.set_transfer_encoding("Chunked") + } + + pub fn set_host(&mut self, host: &'b str) -> &mut Self { + self.set("Host", host) + } + + pub fn set_connection(&mut self, connection: &'b str) -> &mut Self { + self.set("Connection", connection) + } + + pub fn set_connection_close(&mut self) -> &mut Self { + self.set_connection("Close") + } + + pub fn set_connection_keep_alive(&mut self) -> &mut Self { + self.set_connection("Keep-Alive") + } + + pub fn set_connection_upgrade(&mut self) -> &mut Self { + self.set_connection("Upgrade") + } + + pub fn set_cache_control(&mut self, cache: &'b str) -> &mut Self { + self.set("Cache-Control", cache) + } + + pub fn set_cache_control_no_cache(&mut self) -> &mut Self { + self.set_cache_control("No-Cache") + } + + pub fn set_upgrade(&mut self, upgrade: &'b str) -> &mut Self { + self.set("Upgrade", upgrade) + } + + pub fn set_upgrade_websocket(&mut self) -> &mut Self { + self.set_upgrade("websocket") + } + + pub async fn send(&self, output: W) -> Result> + where + W: Write, + { + send_raw_headers(self.iter_raw(), output).await + } +} + +impl<'b, const N: usize> Default for Headers<'b, N> { + fn default() -> Self { + Self::new() + } +} + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub enum BodyType { + Chunked, + ContentLen(u64), + Close, + Unknown, +} + +impl BodyType { + pub fn from_header(name: &str, value: &str) -> Self { + if "Transfer-Encoding".eq_ignore_ascii_case(name) { + if value.eq_ignore_ascii_case("Chunked") { + return Self::Chunked; + } + } else if "Content-Length".eq_ignore_ascii_case(name) { + return Self::ContentLen(value.parse::().unwrap()); // TODO + } else if "Connection".eq_ignore_ascii_case(name) && value.eq_ignore_ascii_case("Close") { + return Self::Close; + } + + Self::Unknown + } + + pub fn from_headers<'a, H>(headers: H) -> Self + where + H: IntoIterator, + { + for (name, value) in headers { + let body = Self::from_header(name, value); + + if body != Self::Unknown { + return body; + } + } + + Self::Unknown + } +} + +pub enum Body<'b, R> { + Close(PartiallyRead<'b, R>), + ContentLen(ContentLenRead>), + Chunked(ChunkedRead<'b, PartiallyRead<'b, R>>), +} + +impl<'b, R> Body<'b, R> +where + R: Read, +{ + pub fn new(body_type: BodyType, buf: &'b mut [u8], read_len: usize, input: R) -> Self { + match body_type { + BodyType::Chunked => Body::Chunked(ChunkedRead::new( + PartiallyRead::new(&[], input), + buf, + read_len, + )), + BodyType::ContentLen(content_len) => Body::ContentLen(ContentLenRead::new( + content_len, + PartiallyRead::new(&buf[..read_len], input), + )), + BodyType::Close => Body::Close(PartiallyRead::new(&buf[..read_len], input)), + BodyType::Unknown => Body::ContentLen(ContentLenRead::new( + 0, + PartiallyRead::new(&buf[..read_len], input), + )), + } + } + + pub fn is_complete(&self) -> bool { + match self { + Self::Close(_) => true, + Self::ContentLen(r) => r.is_complete(), + Self::Chunked(r) => r.is_complete(), + } + } + + pub fn as_raw_reader(&mut self) -> &mut R { + match self { + Self::Close(r) => &mut r.input, + Self::ContentLen(r) => &mut r.input.input, + Self::Chunked(r) => &mut r.input.input, + } + } + + pub fn release(self) -> R { + match self { + Self::Close(r) => r.release(), + Self::ContentLen(r) => r.release().release(), + Self::Chunked(r) => r.release().release(), + } + } +} + +impl<'b, R> ErrorType for Body<'b, R> +where + R: ErrorType, +{ + type Error = Error; +} + +impl<'b, R> Read for Body<'b, R> +where + R: Read, +{ + async fn read(&mut self, buf: &mut [u8]) -> Result { + match self { + Self::Close(read) => Ok(read.read(buf).await.map_err(Error::Io)?), + Self::ContentLen(read) => Ok(read.read(buf).await?), + Self::Chunked(read) => Ok(read.read(buf).await?), + } + } +} + +pub struct PartiallyRead<'b, R> { + buf: &'b [u8], + read_len: usize, + input: R, +} + +impl<'b, R> PartiallyRead<'b, R> { + pub const fn new(buf: &'b [u8], input: R) -> Self { + Self { + buf, + read_len: 0, + input, + } + } + + pub fn buf_len(&self) -> usize { + self.buf.len() + } + + pub fn as_raw_reader(&mut self) -> &mut R { + &mut self.input + } + + pub fn release(self) -> R { + self.input + } +} + +impl<'b, R> ErrorType for PartiallyRead<'b, R> +where + R: ErrorType, +{ + type Error = R::Error; +} + +impl<'b, R> Read for PartiallyRead<'b, R> +where + R: Read, +{ + async fn read(&mut self, buf: &mut [u8]) -> Result { + if self.buf.len() > self.read_len { + let len = min(buf.len(), self.buf.len() - self.read_len); + buf[..len].copy_from_slice(&self.buf[self.read_len..self.read_len + len]); + + self.read_len += len; + + Ok(len) + } else { + Ok(self.input.read(buf).await?) + } + } +} + +pub struct ContentLenRead { + content_len: u64, + read_len: u64, + input: R, +} + +impl ContentLenRead { + pub const fn new(content_len: u64, input: R) -> Self { + Self { + content_len, + read_len: 0, + input, + } + } + + pub fn is_complete(&self) -> bool { + self.content_len == self.read_len + } + + pub fn release(self) -> R { + self.input + } +} + +impl ErrorType for ContentLenRead +where + R: ErrorType, +{ + type Error = Error; +} + +impl Read for ContentLenRead +where + R: Read, +{ + async fn read(&mut self, buf: &mut [u8]) -> Result { + let len = min(buf.len() as _, self.content_len - self.read_len); + if len > 0 { + let read = self + .input + .read(&mut buf[..len as _]) + .await + .map_err(Error::Io)?; + self.read_len += read as u64; + + Ok(read) + } else { + Ok(0) + } + } +} + +pub struct ChunkedRead<'b, R> { + buf: &'b mut [u8], + buf_offset: usize, + buf_len: usize, + input: R, + remain: u64, + complete: bool, +} + +impl<'b, R> ChunkedRead<'b, R> +where + R: Read, +{ + pub fn new(input: R, buf: &'b mut [u8], buf_len: usize) -> Self { + Self { + buf, + buf_offset: 0, + buf_len, + input, + remain: 0, + complete: false, + } + } + + pub fn is_complete(&self) -> bool { + self.complete + } + + pub fn release(self) -> R { + self.input + } + + // The elegant pull parser taken from here: + // https://github.com/kchmck/uhttp_chunked_bytes.rs/blob/master/src/lib.rs + // Changes: + // - Converted to async + // - Iterators removed + // - Simpler error handling + // - Consumption of trailer + async fn next(&mut self) -> Result, Error> { + if self.complete { + return Ok(None); + } + + if self.remain == 0 { + if let Some(size) = self.parse_size().await? { + // If chunk size is zero (final chunk), the stream is finished [RFC7230§4.1]. + if size == 0 { + self.consume_trailer().await?; + self.complete = true; + return Ok(None); + } + + self.remain = size; + } else { + self.complete = true; + return Ok(None); + } + } + + let next = self.input_fetch().await?; + self.remain -= 1; + + // If current chunk is finished, verify it ends with CRLF [RFC7230§4.1]. + if self.remain == 0 { + self.consume_multi(b"\r\n").await?; + } + + Ok(Some(next)) + } + + // Parse the number of bytes in the next chunk. + async fn parse_size(&mut self) -> Result, Error> { + let mut digits = [0_u8; 16]; + + let slice = match self.parse_digits(&mut digits[..]).await? { + // This is safe because the following call to `from_str_radix` does + // its own verification on the bytes. + Some(s) => unsafe { str::from_utf8_unchecked(s) }, + None => return Ok(None), + }; + + let size = u64::from_str_radix(slice, 16).map_err(|_| Error::InvalidBody)?; + + Ok(Some(size)) + } + + // Extract the hex digits for the current chunk size. + async fn parse_digits<'a>( + &'a mut self, + digits: &'a mut [u8], + ) -> Result, Error> { + // Number of hex digits that have been extracted. + let mut len = 0; + + loop { + let b = match self.input_next().await? { + Some(b) => b, + None => { + return if len == 0 { + // If EOF at the beginning of a new chunk, the stream is finished. + Ok(None) + } else { + Err(Error::IncompleteBody) + }; + } + }; + + match b { + b'\r' => { + self.consume(b'\n').await?; + break; + } + b';' => { + self.consume_ext().await?; + break; + } + _ => { + match digits.get_mut(len) { + Some(d) => *d = b, + None => return Err(Error::InvalidBody), + } + + len += 1; + } + } + } + + Ok(Some(&digits[..len])) + } + + // Consume and discard current chunk extension. + // This doesn't check whether the characters up to CRLF actually have correct syntax. + async fn consume_ext(&mut self) -> Result<(), Error> { + self.consume_header().await?; + + Ok(()) + } + + // Consume and discard the optional trailer following the last chunk. + async fn consume_trailer(&mut self) -> Result<(), Error> { + while self.consume_header().await? {} + + Ok(()) + } + + // Consume and discard each header in the optional trailer following the last chunk. + async fn consume_header(&mut self) -> Result> { + let mut first = self.input_fetch().await?; + let mut len = 1; + + loop { + let second = self.input_fetch().await?; + len += 1; + + if first == b'\r' && second == b'\n' { + return Ok(len > 2); + } + + first = second; + } + } + + // Verify the next bytes in the stream match the expectation. + async fn consume_multi(&mut self, bytes: &[u8]) -> Result<(), Error> { + for byte in bytes { + self.consume(*byte).await?; + } + + Ok(()) + } + + // Verify the next byte in the stream is matching the expectation. + async fn consume(&mut self, byte: u8) -> Result<(), Error> { + if self.input_fetch().await? == byte { + Ok(()) + } else { + Err(Error::InvalidBody) + } + } + + async fn input_fetch(&mut self) -> Result> { + self.input_next().await?.ok_or(Error::IncompleteBody) + } + + async fn input_next(&mut self) -> Result, Error> { + if self.buf_offset == self.buf_len { + self.buf_len = self.input.read(self.buf).await.map_err(Error::Io)?; + self.buf_offset = 0; + } + + if self.buf_len > 0 { + let byte = self.buf[self.buf_offset]; + self.buf_offset += 1; + + Ok(Some(byte)) + } else { + Ok(None) + } + } +} + +impl<'b, R> ErrorType for ChunkedRead<'b, R> +where + R: ErrorType, +{ + type Error = Error; +} + +impl<'b, R> Read for ChunkedRead<'b, R> +where + R: Read, +{ + async fn read(&mut self, buf: &mut [u8]) -> Result { + for (index, byte_pos) in buf.iter_mut().enumerate() { + if let Some(byte) = self.next().await? { + *byte_pos = byte; + } else { + return Ok(index); + } + } + + Ok(buf.len()) + } +} + +pub enum SendBody { + Close(W), + ContentLen(ContentLenWrite), + Chunked(ChunkedWrite), +} + +impl SendBody +where + W: Write, +{ + pub fn new(body_type: BodyType, output: W) -> SendBody { + match body_type { + BodyType::Chunked => SendBody::Chunked(ChunkedWrite::new(output)), + BodyType::ContentLen(content_len) => { + SendBody::ContentLen(ContentLenWrite::new(content_len, output)) + } + BodyType::Close => SendBody::Close(output), + BodyType::Unknown => SendBody::ContentLen(ContentLenWrite::new(0, output)), + } + } + + pub fn is_complete(&self) -> bool { + match self { + Self::ContentLen(w) => w.is_complete(), + _ => true, + } + } + + pub async fn finish(&mut self) -> Result<(), Error> + where + W: Write, + { + match self { + Self::Close(_) => (), + Self::ContentLen(_) => (), + Self::Chunked(w) => w.finish().await?, + } + + self.flush().await?; + + Ok(()) + } + + pub fn as_raw_writer(&mut self) -> &mut W { + match self { + Self::Close(w) => w, + Self::ContentLen(w) => &mut w.output, + Self::Chunked(w) => &mut w.output, + } + } + + pub fn release(self) -> W { + match self { + Self::Close(w) => w, + Self::ContentLen(w) => w.release(), + Self::Chunked(w) => w.release(), + } + } +} + +impl ErrorType for SendBody +where + W: ErrorType, +{ + type Error = Error; +} + +impl Write for SendBody +where + W: Write, +{ + async fn write(&mut self, buf: &[u8]) -> Result { + match self { + Self::Close(w) => Ok(w.write(buf).await.map_err(Error::Io)?), + Self::ContentLen(w) => Ok(w.write(buf).await?), + Self::Chunked(w) => Ok(w.write(buf).await?), + } + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + match self { + Self::Close(w) => Ok(w.flush().await.map_err(Error::Io)?), + Self::ContentLen(w) => Ok(w.flush().await?), + Self::Chunked(w) => Ok(w.flush().await?), + } + } +} + +pub struct ContentLenWrite { + content_len: u64, + write_len: u64, + output: W, +} + +impl ContentLenWrite { + pub const fn new(content_len: u64, output: W) -> Self { + Self { + content_len, + write_len: 0, + output, + } + } + + pub fn is_complete(&self) -> bool { + self.content_len == self.write_len + } + + pub fn release(self) -> W { + self.output + } +} + +impl ErrorType for ContentLenWrite +where + W: ErrorType, +{ + type Error = Error; +} + +impl Write for ContentLenWrite +where + W: Write, +{ + async fn write(&mut self, buf: &[u8]) -> Result { + if self.content_len >= self.write_len + buf.len() as u64 { + let write = self.output.write(buf).await.map_err(Error::Io)?; + self.write_len += write as u64; + + Ok(write) + } else { + Err(Error::TooLongBody) + } + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + self.output.flush().await.map_err(Error::Io) + } +} + +pub struct ChunkedWrite { + output: W, +} + +impl ChunkedWrite { + pub const fn new(output: W) -> Self { + Self { output } + } + + pub async fn finish(&mut self) -> Result<(), Error> + where + W: Write, + { + self.output.write_all(b"\r\n").await.map_err(Error::Io) + } + + pub fn release(self) -> W { + self.output + } +} + +impl ErrorType for ChunkedWrite +where + W: ErrorType, +{ + type Error = Error; +} + +impl Write for ChunkedWrite +where + W: Write, +{ + async fn write(&mut self, buf: &[u8]) -> Result { + if !buf.is_empty() { + let mut len_str = heapless::String::<10>::new(); + write!(&mut len_str, "{:X}\r\n", buf.len()).unwrap(); + self.output + .write_all(len_str.as_bytes()) + .await + .map_err(Error::Io)?; + + self.output.write_all(buf).await.map_err(Error::Io)?; + self.output + .write_all("\r\n".as_bytes()) + .await + .map_err(Error::Io)?; + + Ok(buf.len()) + } else { + Ok(0) + } + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + self.output.flush().await.map_err(Error::Io) + } +} + +#[derive(Default, Debug)] +pub struct RequestHeaders<'b, const N: usize> { + pub method: Option, + pub path: Option<&'b str>, + pub headers: Headers<'b, N>, +} + +impl<'b, const N: usize> RequestHeaders<'b, N> { + pub const fn new() -> Self { + Self { + method: None, + path: None, + headers: Headers::::new(), + } + } + + pub async fn receive( + &mut self, + buf: &'b mut [u8], + mut input: R, + ) -> Result<(&'b mut [u8], usize), Error> + where + R: Read, + { + let (read_len, headers_len) = match read_reply_buf::(&mut input, buf, true).await { + Ok(read_len) => read_len, + Err(e) => return Err(e), + }; + + let mut parser = httparse::Request::new(&mut self.headers.0); + + let (headers_buf, body_buf) = buf.split_at_mut(headers_len); + + let status = match parser.parse(headers_buf) { + Ok(status) => status, + Err(e) => return Err(e.into()), + }; + + if let Status::Complete(headers_len2) = status { + if headers_len != headers_len2 { + unreachable!("Should not happen. HTTP header parsing is indeterminate.") + } + + self.method = parser.method.and_then(Method::new); + self.path = parser.path; + + trace!("Received:\n{}", self); + + Ok((body_buf, read_len - headers_len)) + } else { + unreachable!("Secondary parse of already loaded buffer failed.") + } + } + + pub async fn send(&self, mut output: W) -> Result> + where + W: Write, + { + send_request(self.method, self.path, &mut output).await?; + let body_type = self.headers.send(&mut output).await?; + send_headers_end(output).await?; + + Ok(body_type) + } +} + +impl<'b, const N: usize> Display for RequestHeaders<'b, N> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + // if let Some(version) = self.version { + // writeln!(f, "Version {}", version)?; + // } + + if let Some(method) = self.method { + writeln!(f, "{} {}", method, self.path.unwrap_or(""))?; + } + + for (name, value) in self.headers.iter() { + if name.is_empty() { + break; + } + + writeln!(f, "{name}: {value}")?; + } + + Ok(()) + } +} + +#[derive(Default, Debug)] +pub struct ResponseHeaders<'b, const N: usize> { + pub code: Option, + pub reason: Option<&'b str>, + pub headers: Headers<'b, N>, +} + +impl<'b, const N: usize> ResponseHeaders<'b, N> { + pub const fn new() -> Self { + Self { + code: None, + reason: None, + headers: Headers::::new(), + } + } + + pub async fn receive( + &mut self, + buf: &'b mut [u8], + mut input: R, + ) -> Result<(&'b mut [u8], usize), Error> + where + R: Read, + { + let (read_len, headers_len) = read_reply_buf::(&mut input, buf, false).await?; + + let mut parser = httparse::Response::new(&mut self.headers.0); + + let (headers_buf, body_buf) = buf.split_at_mut(headers_len); + + let status = parser.parse(headers_buf).map_err(Error::from)?; + + if let Status::Complete(headers_len2) = status { + if headers_len != headers_len2 { + unreachable!("Should not happen. HTTP header parsing is indeterminate.") + } + + self.code = parser.code; + self.reason = parser.reason; + + trace!("Received:\n{}", self); + + Ok((body_buf, read_len - headers_len)) + } else { + unreachable!("Secondary parse of already loaded buffer failed.") + } + } + + pub async fn send(&self, mut output: W) -> Result> + where + W: Write, + { + send_status(self.code, self.reason, &mut output).await?; + let body_type = self.headers.send(&mut output).await?; + send_headers_end(output).await?; + + Ok(body_type) + } +} + +impl<'b, const N: usize> Display for ResponseHeaders<'b, N> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + // if let Some(version) = self.version { + // writeln!(f, "Version {}", version)?; + // } + + if let Some(code) = self.code { + writeln!(f, "{} {}", code, self.reason.unwrap_or(""))?; + } + + for (name, value) in self.headers.iter() { + if name.is_empty() { + break; + } + + writeln!(f, "{name}: {value}")?; + } + + Ok(()) + } +} + +async fn read_reply_buf( + mut input: R, + buf: &mut [u8], + request: bool, +) -> Result<(usize, usize), Error> +where + R: Read, +{ + let mut offset = 0; + let mut size = 0; + + while buf.len() > size { + let read = input.read(&mut buf[offset..]).await.map_err(Error::Io)?; + + offset += read; + size += read; + + let mut headers = [httparse::EMPTY_HEADER; N]; + + let status = if request { + httparse::Request::new(&mut headers).parse(&buf[..size])? + } else { + httparse::Response::new(&mut headers).parse(&buf[..size])? + }; + + if let httparse::Status::Complete(headers_len) = status { + return Ok((size, headers_len)); + } + } + + Err(Error::TooManyHeaders) +} + +async fn send_status_line( + request: bool, + token: Option<&str>, + extra: Option<&str>, + mut output: W, +) -> Result<(), Error> +where + W: Write, +{ + let mut written = false; + + if !request { + output.write_all(b"HTTP/1.1").await.map_err(Error::Io)?; + written = true; + } + + if let Some(token) = token { + if written { + output.write_all(b" ").await.map_err(Error::Io)?; + } + + output + .write_all(token.as_bytes()) + .await + .map_err(Error::Io)?; + + written = true; + } + + if let Some(extra) = extra { + if written { + output.write_all(b" ").await.map_err(Error::Io)?; + } + + output + .write_all(extra.as_bytes()) + .await + .map_err(Error::Io)?; + + written = true; + } + + if request { + if written { + output.write_all(b" ").await.map_err(Error::Io)?; + } + + output.write_all(b"HTTP/1.1").await.map_err(Error::Io)?; + } + + output.write_all(b"\r\n").await.map_err(Error::Io)?; + + Ok(()) +} + +#[cfg(feature = "embedded-svc")] +mod embedded_svc_compat { + use core::str; + + use embedded_svc::http::client::asynch::Method; + + impl From for super::Method { + fn from(method: Method) -> Self { + match method { + Method::Delete => super::Method::Delete, + Method::Get => super::Method::Get, + Method::Head => super::Method::Head, + Method::Post => super::Method::Post, + Method::Put => super::Method::Put, + Method::Connect => super::Method::Connect, + Method::Options => super::Method::Options, + Method::Trace => super::Method::Trace, + Method::Copy => super::Method::Copy, + Method::Lock => super::Method::Lock, + Method::MkCol => super::Method::MkCol, + Method::Move => super::Method::Move, + Method::Propfind => super::Method::Propfind, + Method::Proppatch => super::Method::Proppatch, + Method::Search => super::Method::Search, + Method::Unlock => super::Method::Unlock, + Method::Bind => super::Method::Bind, + Method::Rebind => super::Method::Rebind, + Method::Unbind => super::Method::Unbind, + Method::Acl => super::Method::Acl, + Method::Report => super::Method::Report, + Method::MkActivity => super::Method::MkActivity, + Method::Checkout => super::Method::Checkout, + Method::Merge => super::Method::Merge, + Method::MSearch => super::Method::MSearch, + Method::Notify => super::Method::Notify, + Method::Subscribe => super::Method::Subscribe, + Method::Unsubscribe => super::Method::Unsubscribe, + Method::Patch => super::Method::Patch, + Method::Purge => super::Method::Purge, + Method::MkCalendar => super::Method::MkCalendar, + Method::Link => super::Method::Link, + Method::Unlink => super::Method::Unlink, + } + } + } + + impl From for Method { + fn from(method: super::Method) -> Self { + match method { + super::Method::Delete => Method::Delete, + super::Method::Get => Method::Get, + super::Method::Head => Method::Head, + super::Method::Post => Method::Post, + super::Method::Put => Method::Put, + super::Method::Connect => Method::Connect, + super::Method::Options => Method::Options, + super::Method::Trace => Method::Trace, + super::Method::Copy => Method::Copy, + super::Method::Lock => Method::Lock, + super::Method::MkCol => Method::MkCol, + super::Method::Move => Method::Move, + super::Method::Propfind => Method::Propfind, + super::Method::Proppatch => Method::Proppatch, + super::Method::Search => Method::Search, + super::Method::Unlock => Method::Unlock, + super::Method::Bind => Method::Bind, + super::Method::Rebind => Method::Rebind, + super::Method::Unbind => Method::Unbind, + super::Method::Acl => Method::Acl, + super::Method::Report => Method::Report, + super::Method::MkActivity => Method::MkActivity, + super::Method::Checkout => Method::Checkout, + super::Method::Merge => Method::Merge, + super::Method::MSearch => Method::MSearch, + super::Method::Notify => Method::Notify, + super::Method::Subscribe => Method::Subscribe, + super::Method::Unsubscribe => Method::Unsubscribe, + super::Method::Patch => Method::Patch, + super::Method::Purge => Method::Purge, + super::Method::MkCalendar => Method::MkCalendar, + super::Method::Link => Method::Link, + super::Method::Unlink => Method::Unlink, + } + } + } + + impl<'b, const N: usize> embedded_svc::http::Query for super::RequestHeaders<'b, N> { + fn uri(&self) -> &'_ str { + self.path.unwrap_or("") + } + + fn method(&self) -> Method { + self.method.unwrap_or(super::Method::Get).into() + } + } + + impl<'b, const N: usize> embedded_svc::http::Headers for super::RequestHeaders<'b, N> { + fn header(&self, name: &str) -> Option<&'_ str> { + self.headers.get(name) + } + } + + impl<'b, const N: usize> embedded_svc::http::Status for super::ResponseHeaders<'b, N> { + fn status(&self) -> u16 { + self.code.unwrap_or(200) + } + + fn status_message(&self) -> Option<&'_ str> { + self.reason + } + } + + impl<'b, const N: usize> embedded_svc::http::Headers for super::ResponseHeaders<'b, N> { + fn header(&self, name: &str) -> Option<&'_ str> { + self.headers.get(name) + } + } + + impl<'b, const N: usize> embedded_svc::http::Headers for super::Headers<'b, N> { + fn header(&self, name: &str) -> Option<&'_ str> { + self.get(name) + } + } +} diff --git a/src/asynch/http/server.rs b/edge-http/src/server.rs similarity index 99% rename from src/asynch/http/server.rs rename to edge-http/src/server.rs index c8dd8cf..8dc4196 100644 --- a/src/asynch/http/server.rs +++ b/edge-http/src/server.rs @@ -7,7 +7,7 @@ use embedded_io_async::{Read, Write}; use log::{info, warn}; -use crate::asynch::http::{ +use crate::{ send_headers, send_headers_end, send_status, Body, BodyType, Error, Method, RequestHeaders, SendBody, }; @@ -336,7 +336,7 @@ pub struct Server { impl Server where - A: crate::asynch::tcp::TcpAccept, + A: edge_tcp::TcpAccept, H: for<'b, 't> Handler<'b, N, &'b mut A::Connection<'t>>, { pub const fn new(acceptor: A, handler: H) -> Self { diff --git a/edge-mdns/Cargo.toml b/edge-mdns/Cargo.toml new file mode 100644 index 0000000..d1e1833 --- /dev/null +++ b/edge-mdns/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "edge-mdns" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +log.workspace = true +heapless.workspace = true + +domain = { version = "0.7", default-features = false } diff --git a/src/mdns.rs b/edge-mdns/src/lib.rs similarity index 99% rename from src/mdns.rs rename to edge-mdns/src/lib.rs index 6fac2c7..df8324a 100644 --- a/src/mdns.rs +++ b/edge-mdns/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "std"), no_std)] + use core::fmt::Write; use core::str::FromStr; diff --git a/edge-mqtt/Cargo.toml b/edge-mqtt/Cargo.toml new file mode 100644 index 0000000..db6c145 --- /dev/null +++ b/edge-mqtt/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "edge-mqtt" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +std = [] + +[dependencies] +rumqttc = { version = "0.19" } diff --git a/src/asynch/rumqttc.rs b/edge-mqtt/src/lib.rs similarity index 99% rename from src/asynch/rumqttc.rs rename to edge-mqtt/src/lib.rs index 882d294..f8c6068 100644 --- a/src/asynch/rumqttc.rs +++ b/edge-mqtt/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "std"), no_std)] + pub use rumqttc::*; #[cfg(feature = "embedded-svc")] diff --git a/edge-tcp/Cargo.toml b/edge-tcp/Cargo.toml new file mode 100644 index 0000000..a76ddf8 --- /dev/null +++ b/edge-tcp/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "edge-tcp" +version = "0.1.0" +authors = ["Ivan Markov "] +edition = "2021" +categories = ["embedded", "hardware-support"] +keywords = ["embedded", "svc", "network"] +description = "TCP traits for edge-net" +repository = "https://github.com/ivmarkov/edge-net" +license = "MIT OR Apache-2.0" +readme = "README.md" +rust-version = "1.71" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +embedded-io.workspace = true +embedded-io-async.workspace = true +no-std-net.workspace = true \ No newline at end of file diff --git a/edge-tcp/README.md b/edge-tcp/README.md new file mode 100644 index 0000000..7cef1ed --- /dev/null +++ b/edge-tcp/README.md @@ -0,0 +1 @@ +# TCP traits for `edge-net` diff --git a/src/asynch/tcp.rs b/edge-tcp/src/lib.rs similarity index 95% rename from src/asynch/tcp.rs rename to edge-tcp/src/lib.rs index 5e3fc0b..3f9538d 100644 --- a/src/asynch/tcp.rs +++ b/edge-tcp/src/lib.rs @@ -1,5 +1,12 @@ -use core::fmt::Debug; +#![no_std] +#![allow(stable_features)] +#![allow(unknown_lints)] +#![feature(async_fn_in_trait)] +#![allow(async_fn_in_trait)] +#![feature(impl_trait_projections)] +#![feature(impl_trait_in_assoc_type)] +use core::fmt::Debug; use no_std_net::SocketAddr; pub trait TcpSplittableConnection { diff --git a/edge-ws/Cargo.toml b/edge-ws/Cargo.toml new file mode 100644 index 0000000..dde441c --- /dev/null +++ b/edge-ws/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "edge-ws" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +edge-http.workspace = true +embedded-io-async.workspace = true +embedded-nal-async.workspace = true + +base64 = { version = "0.13", default-features = false } +sha1_smol = { version = "1", default-features = false } \ No newline at end of file diff --git a/src/asynch/ws.rs b/edge-ws/src/lib.rs similarity index 89% rename from src/asynch/ws.rs rename to edge-ws/src/lib.rs index 21a49bc..dbeb6d9 100644 --- a/src/asynch/ws.rs +++ b/edge-ws/src/lib.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "std"), no_std)] + use core::cmp::min; use embedded_io_async::{Read, ReadExactError, Write}; @@ -379,6 +381,8 @@ where } pub mod http { + use edge_http::Headers; + pub const NONCE_LEN: usize = 16; pub const MAX_BASE64_KEY_LEN: usize = 28; pub const MAX_BASE64_KEY_RESPONSE_LEN: usize = 33; @@ -505,7 +509,7 @@ pub mod http { pub mod client { use embedded_nal_async::TcpConnect; - use crate::asynch::http::{client::ClientConnection, Error, Method}; + use edge_http::{client::ClientConnection, Error, Method}; use super::{upgrade_request_headers, MAX_BASE64_KEY_LEN, NONCE_LEN}; @@ -555,6 +559,71 @@ pub mod http { Ok(succeeded) } } + + pub trait HeaderExt<'b> { + fn is_ws_upgrade_request(&self) -> bool; + + fn set_ws_upgrade_request_headers( + &mut self, + host: Option<&'b str>, + origin: Option<&'b str>, + version: Option<&'b str>, + nonce: &[u8; crate::http::NONCE_LEN], + nonce_base64_buf: &'b mut [u8; crate::http::MAX_BASE64_KEY_LEN], + ) -> &mut Self; + + fn set_ws_upgrade_response_headers<'a, H>( + &mut self, + request_headers: H, + version: Option<&'a str>, + sec_key_response_base64_buf: &'b mut [u8; crate::http::MAX_BASE64_KEY_RESPONSE_LEN], + ) -> Result<&mut Self, UpgradeError> + where + H: IntoIterator; + } + + impl<'b, const N: usize> HeaderExt<'b> for Headers<'b, N> { + fn is_ws_upgrade_request(&self) -> bool { + crate::http::is_upgrade_request(self.iter()) + } + + fn set_ws_upgrade_request_headers( + &mut self, + host: Option<&'b str>, + origin: Option<&'b str>, + version: Option<&'b str>, + nonce: &[u8; crate::http::NONCE_LEN], + nonce_base64_buf: &'b mut [u8; crate::http::MAX_BASE64_KEY_LEN], + ) -> &mut Self { + for (name, value) in + crate::http::upgrade_request_headers(host, origin, version, nonce, nonce_base64_buf) + { + self.set(name, value); + } + + self + } + + fn set_ws_upgrade_response_headers<'a, H>( + &mut self, + request_headers: H, + version: Option<&'a str>, + sec_key_response_base64_buf: &'b mut [u8; crate::http::MAX_BASE64_KEY_RESPONSE_LEN], + ) -> Result<&mut Self, UpgradeError> + where + H: IntoIterator, + { + for (name, value) in crate::http::upgrade_response_headers( + request_headers, + version, + sec_key_response_base64_buf, + )? { + self.set(name, value); + } + + Ok(self) + } + } } #[cfg(feature = "embedded-svc")] diff --git a/src/asynch.rs b/src/asynch.rs index bff150e..b08d473 100644 --- a/src/asynch.rs +++ b/src/asynch.rs @@ -1,64 +1,50 @@ -pub mod dhcp; -pub mod http; -pub mod io; -#[cfg(all(feature = "std", feature = "rumqttc"))] -pub mod rumqttc; -#[cfg(feature = "std")] -pub mod stdnal; -pub mod tcp; -pub mod ws; - -pub use unblocker::Unblocker; - #[cfg(feature = "embedded-svc")] pub use embedded_svc_compat::*; -mod unblocker { - use core::future::Future; - - pub trait Unblocker { - type UnblockFuture<'a, F, T>: Future + Send - where - Self: 'a, - F: Send + 'a, - T: Send + 'a; +use core::future::Future; - fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> - where - F: FnOnce() -> T + Send + 'a, - T: Send + 'a; - } +pub trait Unblocker { + type UnblockFuture<'a, F, T>: Future + Send + where + Self: 'a, + F: Send + 'a, + T: Send + 'a; - impl Unblocker for &U + fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> where - U: Unblocker, - { - type UnblockFuture<'a, F, T> - = U::UnblockFuture<'a, F, T> where Self: 'a, F: Send + 'a, T: Send + 'a; + F: FnOnce() -> T + Send + 'a, + T: Send + 'a; +} - fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> - where - F: FnOnce() -> T + Send + 'a, - T: Send + 'a, - { - (*self).unblock(f) - } - } +impl Unblocker for &U +where + U: Unblocker, +{ + type UnblockFuture<'a, F, T> + = U::UnblockFuture<'a, F, T> where Self: 'a, F: Send + 'a, T: Send + 'a; - impl Unblocker for &mut U + fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> where - U: Unblocker, + F: FnOnce() -> T + Send + 'a, + T: Send + 'a, { - type UnblockFuture<'a, F, T> + (*self).unblock(f) + } +} + +impl Unblocker for &mut U +where + U: Unblocker, +{ + type UnblockFuture<'a, F, T> = U::UnblockFuture<'a, F, T> where Self: 'a, F: Send + 'a, T: Send + 'a; - fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> - where - F: FnOnce() -> T + Send + 'a, - T: Send + 'a, - { - (**self).unblock(f) - } + fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> + where + F: FnOnce() -> T + Send + 'a, + T: Send + 'a, + { + (**self).unblock(f) } } diff --git a/src/asynch/io.rs b/src/asynch/io.rs deleted file mode 100644 index 7d1f804..0000000 --- a/src/asynch/io.rs +++ /dev/null @@ -1,102 +0,0 @@ -use embedded_io::Error; -use embedded_io_async::{Read, Write}; - -pub async fn try_read_full( - mut read: R, - buf: &mut [u8], -) -> Result { - let mut offset = 0; - let mut size = 0; - - loop { - let size_read = read.read(&mut buf[offset..]).await.map_err(|e| (e, size))?; - - offset += size_read; - size += size_read; - - if size_read == 0 || size == buf.len() { - break; - } - } - - Ok(size) -} - -#[derive(Debug)] -pub enum CopyError { - Read(R), - Write(W), -} - -impl Error for CopyError -where - R: Error, - W: Error, -{ - fn kind(&self) -> embedded_io::ErrorKind { - match self { - Self::Read(e) => e.kind(), - Self::Write(e) => e.kind(), - } - } -} - -pub async fn copy( - read: R, - write: W, -) -> Result> -where - R: Read, - W: Write, -{ - copy_len::(read, write, u64::MAX).await -} - -pub async fn copy_len( - read: R, - write: W, - len: u64, -) -> Result> -where - R: Read, - W: Write, -{ - copy_len_with_progress::(read, write, len, |_, _| {}).await -} - -pub async fn copy_len_with_progress( - mut read: R, - mut write: W, - mut len: u64, - progress: P, -) -> Result> -where - R: Read, - W: Write, - P: Fn(u64, u64), -{ - let mut buf = [0_u8; N]; - - let mut copied = 0; - - while len > 0 { - progress(copied, len); - - let size_read = read.read(&mut buf).await.map_err(CopyError::Read)?; - if size_read == 0 { - break; - } - - write - .write_all(&buf[0..size_read]) - .await - .map_err(CopyError::Write)?; - - copied += size_read as u64; - len -= size_read as u64; - } - - progress(copied, len); - - Ok(copied) -} diff --git a/src/captive.rs b/src/captive.rs deleted file mode 100644 index 3ff4109..0000000 --- a/src/captive.rs +++ /dev/null @@ -1,259 +0,0 @@ -use core::fmt; -use core::time::Duration; - -use log::debug; - -use domain::{ - base::{ - iana::{Class, Opcode, Rcode}, - octets::*, - Record, Rtype, - }, - rdata::A, -}; - -#[cfg(feature = "std")] -pub use server::*; - -#[derive(Debug)] -pub struct InnerError(T); - -#[derive(Debug)] -pub enum DnsError { - ShortBuf(InnerError), - ParseError(InnerError), -} - -impl fmt::Display for DnsError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - DnsError::ShortBuf(e) => e.0.fmt(f), - DnsError::ParseError(e) => e.0.fmt(f), - } - } -} - -impl From for DnsError { - fn from(e: ShortBuf) -> Self { - Self::ShortBuf(InnerError(e)) - } -} - -impl From for DnsError { - fn from(e: ParseError) -> Self { - Self::ParseError(InnerError(e)) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for DnsError {} - -pub fn process_dns_request( - request: impl AsRef<[u8]>, - ip: &[u8; 4], - ttl: Duration, -) -> Result, DnsError> { - let request = request.as_ref(); - let response = Octets512::new(); - - let message = domain::base::Message::from_octets(request)?; - debug!("Processing message with header: {:?}", message.header()); - - let mut responseb = domain::base::MessageBuilder::from_target(response)?; - - let response = if matches!(message.header().opcode(), Opcode::Query) { - debug!("Message is of type Query, processing all questions"); - - let mut answerb = responseb.start_answer(&message, Rcode::NoError)?; - - for question in message.question() { - let question = question?; - - if matches!(question.qtype(), Rtype::A) { - debug!( - "Question {:?} is of type A, answering with IP {:?}, TTL {:?}", - question, ip, ttl - ); - - let record = Record::new( - question.qname(), - Class::In, - ttl.as_secs() as u32, - A::from_octets(ip[0], ip[1], ip[2], ip[3]), - ); - debug!("Answering question {:?} with {:?}", question, record); - answerb.push(record)?; - } else { - debug!("Question {:?} is not of type A, not answering", question); - } - } - - answerb.finish() - } else { - debug!("Message is not of type Query, replying with NotImp"); - - let headerb = responseb.header_mut(); - - headerb.set_id(message.header().id()); - headerb.set_opcode(message.header().opcode()); - headerb.set_rd(message.header().rd()); - headerb.set_rcode(domain::base::iana::Rcode::NotImp); - - responseb.finish() - }; - Ok(response) -} - -#[cfg(feature = "std")] -mod server { - use std::{ - io, mem, - net::{Ipv4Addr, SocketAddrV4, UdpSocket}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - thread::{self, JoinHandle}, - time::Duration, - }; - - use log::*; - - #[derive(Clone, Debug)] - pub struct DnsConf { - pub bind_ip: Ipv4Addr, - pub bind_port: u16, - pub ip: Ipv4Addr, - pub ttl: Duration, - } - - impl DnsConf { - pub fn new(ip: Ipv4Addr) -> Self { - Self { - bind_ip: Ipv4Addr::new(0, 0, 0, 0), - bind_port: 53, - ip, - ttl: Duration::from_secs(60), - } - } - } - - #[derive(Debug)] - pub enum Status { - Stopped, - Started, - Error(io::Error), - } - - pub struct DnsServer { - conf: DnsConf, - status: Status, - running: Arc, - handle: Option>>, - } - - impl DnsServer { - pub fn new(conf: DnsConf) -> Self { - Self { - conf, - status: Status::Stopped, - running: Arc::new(AtomicBool::new(false)), - handle: None, - } - } - - pub fn get_status(&mut self) -> &Status { - self.cleanup(); - &self.status - } - - pub fn start(&mut self) -> Result<(), io::Error> { - if matches!(self.get_status(), Status::Started) { - return Ok(()); - } - let socket_address = SocketAddrV4::new(self.conf.bind_ip, self.conf.bind_port); - let running = self.running.clone(); - let ip = self.conf.ip; - let ttl = self.conf.ttl; - - self.running.store(true, Ordering::Relaxed); - self.handle = Some( - thread::Builder::new() - // default stack size is not enough - // 9000 was found via trial and error - .stack_size(9000) - .spawn(move || { - // Socket is not movable across thread bounds - // Otherwise we run into an assertion error here: https://github.com/espressif/esp-idf/blob/master/components/lwip/port/esp32/freertos/sys_arch.c#L103 - let socket = UdpSocket::bind(socket_address)?; - socket.set_read_timeout(Some(Duration::from_secs(1)))?; - let result = Self::run(&running, ip, ttl, socket); - - running.store(false, Ordering::Relaxed); - - result - }) - .unwrap(), - ); - - Ok(()) - } - - pub fn stop(&mut self) -> Result<(), io::Error> { - if matches!(self.get_status(), Status::Stopped) { - return Ok(()); - } - - self.running.store(false, Ordering::Relaxed); - self.cleanup(); - - let mut status = Status::Stopped; - mem::swap(&mut self.status, &mut status); - - match status { - Status::Error(e) => Err(e), - _ => Ok(()), - } - } - - fn cleanup(&mut self) { - if !self.running.load(Ordering::Relaxed) && self.handle.is_some() { - self.status = match mem::take(&mut self.handle).unwrap().join().unwrap() { - Ok(_) => Status::Stopped, - Err(e) => Status::Error(e), - }; - } - } - - fn run( - running: &AtomicBool, - ip: Ipv4Addr, - ttl: Duration, - socket: UdpSocket, - ) -> Result<(), io::Error> { - while running.load(Ordering::Relaxed) { - let mut request_arr = [0_u8; 512]; - debug!("Waiting for data"); - let (request_len, source_addr) = match socket.recv_from(&mut request_arr) { - Ok(value) => value, - Err(err) => match err.kind() { - std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut => continue, - _ => return Err(err), - }, - }; - - let request = &request_arr[..request_len]; - - debug!("Received {} bytes from {}", request.len(), source_addr); - let response = super::process_dns_request(request, &ip.octets(), ttl) - .map_err(|_| io::ErrorKind::Other)?; - - socket.send_to(response.as_ref(), source_addr)?; - - debug!("Sent {} bytes to {}", response.as_ref().len(), source_addr); - } - - Ok(()) - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 75dbdd8..49ef4de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,23 @@ #![cfg_attr(not(feature = "std"), no_std)] #![allow(stable_features)] -#![allow(unknown_lints)] -#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] -#![cfg_attr(feature = "nightly", allow(async_fn_in_trait))] -#![cfg_attr(feature = "nightly", feature(impl_trait_projections))] -#![cfg_attr(feature = "nightly", feature(impl_trait_in_assoc_type))] +#![cfg_attr(feature = "nightly", feature(impl_trait_in_assoc_type))] // Used in Unblocker + +// Re-export enabled sub-crates +#[cfg(feature = "edge-captive")] +pub use edge_captive as captive; +#[cfg(feature = "edge-dhcp")] +pub use edge_dhcp as dhcp; +#[cfg(feature = "edge-http")] +pub use edge_http as http; +#[cfg(feature = "edge-mdns")] +pub use edge_mdns as mdns; +#[cfg(feature = "edge-mqtt")] +pub use edge_mqtt as mqtt; +#[cfg(feature = "edge-ws")] +pub use edge_ws as ws; #[cfg(feature = "nightly")] pub mod asynch; -#[cfg(feature = "domain")] -pub mod captive; -pub mod dhcp; -#[cfg(feature = "domain")] -pub mod mdns; + #[cfg(feature = "std")] -pub mod std_mutex; +pub mod std; diff --git a/src/std_mutex.rs b/src/std.rs similarity index 89% rename from src/std_mutex.rs rename to src/std.rs index 5069dc5..8061cf0 100644 --- a/src/std_mutex.rs +++ b/src/std.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "nightly")] +pub mod nal; + use embassy_sync::blocking_mutex::raw::RawMutex; pub struct StdRawMutex(std::sync::Mutex<()>); diff --git a/src/asynch/stdnal.rs b/src/std/nal.rs similarity index 98% rename from src/asynch/stdnal.rs rename to src/std/nal.rs index 1f08c66..4c969b8 100644 --- a/src/asynch/stdnal.rs +++ b/src/std/nal.rs @@ -5,16 +5,14 @@ use std::os::fd::{AsFd, AsRawFd}; use async_io::Async; use futures_lite::io::{AsyncReadExt, AsyncWriteExt}; -use embedded_io::ErrorType; -use embedded_io_async::{Read, Write}; +use embedded_io_async::{ErrorType, Read, Write}; +use no_std_net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use embedded_nal_async::{ AddrType, ConnectedUdp, Dns, IpAddr, TcpConnect, UdpStack, UnconnectedUdp, }; -use no_std_net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; - -use super::tcp::{RawSocket, RawStack, TcpAccept, TcpListen, TcpSplittableConnection}; +use edge_tcp::{RawSocket, RawStack, TcpAccept, TcpListen, TcpSplittableConnection}; pub struct StdTcpConnect(()); diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index 4e570cc..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,4 +0,0 @@ -#[cfg(all(feature = "embedded-svc", feature = "nightly"))] -pub mod ghota; -pub mod io; -pub mod json_io;