Skip to content

Commit

Permalink
Split up edge-net into smaller crates (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
bugadani authored Dec 5, 2023
1 parent 38a6f97 commit 4a80e47
Show file tree
Hide file tree
Showing 28 changed files with 2,126 additions and 461 deletions.
88 changes: 65 additions & 23 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,50 @@ readme = "README.md"
rust-version = "1.71"

[features]
default = ["std"]
std = [
"alloc",
"embassy-sync/std",
"embedded-svc?/std",
"edge-http?/std",
"edge-captive?/std",
"dep:edge-tcp",

std = ["alloc", "embedded-io/std", "embassy-sync/std", "embedded-svc?/std", "async-io", "libc", "futures-lite/std", "httparse/std", "domain?/std"]
alloc = ["embedded-io/alloc", "embedded-svc?/alloc", "embedded-io-async?/alloc", ]
nightly = ["embedded-io-async", "embassy-sync/nightly", "embedded-svc?/nightly"]
embassy-util = [] # Just for backward compat. To be removed in future
"futures-lite/std",
"dep:async-io",
"dep:embedded-io-async",
"dep:embedded-nal-async",
"dep:libc"
]
alloc = ["edge-http?/alloc", "embedded-svc?/alloc"]
nightly = ["dep:edge-http", "embedded-svc?/nightly"]

[dependencies]
embedded-io = { version = "0.6", default-features = false }
embedded-io-async = { version = "0.6", default-features = false, optional = true }
heapless = { version = "0.7", default-features = false }
httparse = { version = "1.7", default-features = false }
num_enum = { version = "0.7", default-features = false }
embassy-futures = "0.1"
embassy-time = "0.1"
embassy-sync = "0.3"
no-std-net = { version = "0.6", default-features = false }
rand_core = "0.6"
log = { version = "0.4", default-features = false }
base64 = { version = "0.13", default-features = false }
sha1_smol = { version = "1", default-features = false }
embedded-nal-async = "0.6"
domain = { version = "0.7", default-features = false, optional = true }
embedded-svc = { version = "0.26", default-features = false, optional = true, features = ["embedded-io-async"] }

futures-lite = { version = "1", default-features = false, optional = true }
async-io = { version = "2", default-features = false, optional = true }
embedded-io-async = { workspace = true, optional = true }
embedded-nal-async = { workspace = true, optional = true }
libc = { version = "0.2", default-features = false, optional = true }
embedded-svc = { version = "0.26", default-features = false, optional = true, features = ["embedded-io-async"] }
rumqttc = { version = "0.19", optional = true }

embassy-sync.workspace = true
log.workspace = true
heapless.workspace = true
no-std-net.workspace = true

edge-captive = { workspace = true, optional = true }
edge-dhcp = { workspace = true, optional = true }
edge-http = { workspace = true, optional = true }
edge-mdns = { workspace = true, optional = true }
edge-mqtt = { workspace = true, optional = true }
edge-ws = { workspace = true, optional = true }

# edge-tcp is an exception that only contains trait definitions
edge-tcp = { workspace = true, optional = true }

[dev-dependencies]
anyhow = "1"
simple_logger="2.2"
simple_logger = "2.2"

[[example]]
name = "captive_portal"
Expand All @@ -60,3 +72,33 @@ required-features = ["std", "nightly"]
[[example]]
name = "http_server"
required-features = ["std", "nightly"]

[workspace]
members = [
".",
"edge-captive",
"edge-dhcp",
"edge-http",
"edge-mdns",
"edge-mqtt",
"edge-tcp",
"edge-ws"
]

[workspace.dependencies]
embassy-futures = "0.1"
embassy-sync = "0.3"
embedded-io = { version = "0.6", default-features = false }
embedded-io-async = { version = "0.6", default-features = false }
embedded-nal-async = "0.6"
log = { version = "0.4", default-features = false }
heapless = { version = "0.7", default-features = false }
no-std-net = { version = "0.6", default-features = false }

edge-captive = { version = "0.1.0", path = "edge-captive" }
edge-dhcp = { version = "0.1.0", path = "edge-dhcp" }
edge-http = { version = "0.1.0", path = "edge-http" }
edge-mdns = { version = "0.1.0", path = "edge-mdns" }
edge-mqtt = { version = "0.1.0", path = "edge-mqtt" }
edge-tcp = { version = "0.1.0", path = "edge-tcp" }
edge-ws = { version = "0.1.0", path = "edge-ws" }
14 changes: 14 additions & 0 deletions edge-captive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "edge-captive"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
std = ["domain/std"]

[dependencies]
log.workspace = true

domain = { version = "0.7", default-features = false }
110 changes: 110 additions & 0 deletions edge-captive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#![cfg_attr(not(feature = "std"), no_std)]

use core::fmt;
use core::time::Duration;

use log::debug;

use domain::{
base::{
iana::{Class, Opcode, Rcode},
octets::*,
Record, Rtype,
},
rdata::A,
};

#[cfg(feature = "std")]
mod server;

#[cfg(feature = "std")]
pub use server::*;

#[derive(Debug)]
pub struct InnerError<T: fmt::Debug + fmt::Display>(T);

#[derive(Debug)]
pub enum DnsError {
ShortBuf(InnerError<ShortBuf>),
ParseError(InnerError<ParseError>),
}

impl fmt::Display for DnsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DnsError::ShortBuf(e) => e.0.fmt(f),
DnsError::ParseError(e) => e.0.fmt(f),
}
}
}

impl From<ShortBuf> for DnsError {
fn from(e: ShortBuf) -> Self {
Self::ShortBuf(InnerError(e))
}
}

impl From<ParseError> for DnsError {
fn from(e: ParseError) -> Self {
Self::ParseError(InnerError(e))
}
}

#[cfg(feature = "std")]
impl std::error::Error for DnsError {}

pub fn process_dns_request(
request: impl AsRef<[u8]>,
ip: &[u8; 4],
ttl: Duration,
) -> Result<impl AsRef<[u8]>, DnsError> {
let request = request.as_ref();
let response = Octets512::new();

let message = domain::base::Message::from_octets(request)?;
debug!("Processing message with header: {:?}", message.header());

let mut responseb = domain::base::MessageBuilder::from_target(response)?;

let response = if matches!(message.header().opcode(), Opcode::Query) {
debug!("Message is of type Query, processing all questions");

let mut answerb = responseb.start_answer(&message, Rcode::NoError)?;

for question in message.question() {
let question = question?;

if matches!(question.qtype(), Rtype::A) {
debug!(
"Question {:?} is of type A, answering with IP {:?}, TTL {:?}",
question, ip, ttl
);

let record = Record::new(
question.qname(),
Class::In,
ttl.as_secs() as u32,
A::from_octets(ip[0], ip[1], ip[2], ip[3]),
);
debug!("Answering question {:?} with {:?}", question, record);
answerb.push(record)?;
} else {
debug!("Question {:?} is not of type A, not answering", question);
}
}

answerb.finish()
} else {
debug!("Message is not of type Query, replying with NotImp");

let headerb = responseb.header_mut();

headerb.set_id(message.header().id());
headerb.set_opcode(message.header().opcode());
headerb.set_rd(message.header().rd());
headerb.set_rcode(domain::base::iana::Rcode::NotImp);

responseb.finish()
};
Ok(response)
}
150 changes: 150 additions & 0 deletions edge-captive/src/server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use std::{
io, mem,
net::{Ipv4Addr, SocketAddrV4, UdpSocket},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
thread::{self, JoinHandle},
time::Duration,
};

use log::*;

#[derive(Clone, Debug)]
pub struct DnsConf {
pub bind_ip: Ipv4Addr,
pub bind_port: u16,
pub ip: Ipv4Addr,
pub ttl: Duration,
}

impl DnsConf {
pub fn new(ip: Ipv4Addr) -> Self {
Self {
bind_ip: Ipv4Addr::new(0, 0, 0, 0),
bind_port: 53,
ip,
ttl: Duration::from_secs(60),
}
}
}

#[derive(Debug)]
pub enum Status {
Stopped,
Started,
Error(io::Error),
}

pub struct DnsServer {
conf: DnsConf,
status: Status,
running: Arc<AtomicBool>,
handle: Option<JoinHandle<Result<(), io::Error>>>,
}

impl DnsServer {
pub fn new(conf: DnsConf) -> Self {
Self {
conf,
status: Status::Stopped,
running: Arc::new(AtomicBool::new(false)),
handle: None,
}
}

pub fn get_status(&mut self) -> &Status {
self.cleanup();
&self.status
}

pub fn start(&mut self) -> Result<(), io::Error> {
if matches!(self.get_status(), Status::Started) {
return Ok(());
}
let socket_address = SocketAddrV4::new(self.conf.bind_ip, self.conf.bind_port);
let running = self.running.clone();
let ip = self.conf.ip;
let ttl = self.conf.ttl;

self.running.store(true, Ordering::Relaxed);
self.handle = Some(
thread::Builder::new()
// default stack size is not enough
// 9000 was found via trial and error
.stack_size(9000)
.spawn(move || {
// Socket is not movable across thread bounds
// Otherwise we run into an assertion error here: https://github.com/espressif/esp-idf/blob/master/components/lwip/port/esp32/freertos/sys_arch.c#L103
let socket = UdpSocket::bind(socket_address)?;
socket.set_read_timeout(Some(Duration::from_secs(1)))?;
let result = Self::run(&running, ip, ttl, socket);

running.store(false, Ordering::Relaxed);

result
})
.unwrap(),
);

Ok(())
}

pub fn stop(&mut self) -> Result<(), io::Error> {
if matches!(self.get_status(), Status::Stopped) {
return Ok(());
}

self.running.store(false, Ordering::Relaxed);
self.cleanup();

let mut status = Status::Stopped;
mem::swap(&mut self.status, &mut status);

match status {
Status::Error(e) => Err(e),
_ => Ok(()),
}
}

fn cleanup(&mut self) {
if !self.running.load(Ordering::Relaxed) && self.handle.is_some() {
self.status = match mem::take(&mut self.handle).unwrap().join().unwrap() {
Ok(_) => Status::Stopped,
Err(e) => Status::Error(e),
};
}
}

fn run(
running: &AtomicBool,
ip: Ipv4Addr,
ttl: Duration,
socket: UdpSocket,
) -> Result<(), io::Error> {
while running.load(Ordering::Relaxed) {
let mut request_arr = [0_u8; 512];
debug!("Waiting for data");
let (request_len, source_addr) = match socket.recv_from(&mut request_arr) {
Ok(value) => value,
Err(err) => match err.kind() {
std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut => continue,
_ => return Err(err),
},
};

let request = &request_arr[..request_len];

debug!("Received {} bytes from {}", request.len(), source_addr);
let response = super::process_dns_request(request, &ip.octets(), ttl)
.map_err(|_| io::ErrorKind::Other)?;

socket.send_to(response.as_ref(), source_addr)?;

debug!("Sent {} bytes to {}", response.as_ref().len(), source_addr);
}

Ok(())
}
}
Loading

0 comments on commit 4a80e47

Please sign in to comment.