From fa23757e6cb08e21a66ba341b4a56cd3ff707838 Mon Sep 17 00:00:00 2001 From: Kieran Moy Date: Sat, 3 Aug 2024 17:03:44 +0800 Subject: [PATCH] support config hot reload --- Cargo.lock | 31 ++++ Cargo.toml | 2 + README.md | 16 +- config.yaml | 2 +- src/config.rs | 6 +- src/init.rs | 139 ++++++++++++++++ src/lib.rs | 2 + src/main.rs | 435 +------------------------------------------------- src/route.rs | 3 +- src/zest.rs | 384 ++++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 575 insertions(+), 445 deletions(-) create mode 100644 src/init.rs create mode 100644 src/zest.rs diff --git a/Cargo.lock b/Cargo.lock index 9dda97c..31a1297 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -120,6 +120,16 @@ dependencies = [ "event-listener", ] +[[package]] +name = "async-rwlock" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261803dcc39ba9e72760ba6e16d0199b1eef9fc44e81bffabbebb9f5aea3906c" +dependencies = [ + "async-mutex", + "event-listener", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -758,6 +768,25 @@ dependencies = [ "tempfile", ] +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "smallvec" version = "1.13.2" @@ -1184,6 +1213,7 @@ name = "zest" version = "0.1.8" dependencies = [ "async-mutex", + "async-rwlock", "chrono", "clap", "ipnet", @@ -1195,6 +1225,7 @@ dependencies = [ "mime_guess", "serde", "serde_yml", + "signal-hook", "tokio", "urlencoding", ] diff --git a/Cargo.toml b/Cargo.toml index f5276c8..aaceed2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ log = ["dep:log"] [dependencies] async-mutex = "1.4.0" +async-rwlock = "1.3.0" chrono = { version = "0.4.38", features = ["clock", "now"] } clap = { version = "4.5.7", features = ["derive"] } ipnet = { version = "2.9.0", optional = true } @@ -29,6 +30,7 @@ mime = { version = "0.3.17" } mime_guess = { version = "2.0.4" } serde = { version = "1.0.203", features = ["derive"] } serde_yml = "0.0.10" +signal-hook = "0.3.17" tokio = { version = "1.38.0", features = [ "rt-multi-thread", "fs", diff --git a/README.md b/README.md index d19e267..0bac8a8 100644 --- a/README.md +++ b/README.md @@ -47,16 +47,16 @@ logging: # optional **Benchmark (wrk)** + cargo run --release --no-default-features --features=lru_cache -- -p 8080 ```text -Running 10s test @ http://localhost:8080/ +Running 10s test @ http://localhost:8080 4 threads and 10 connections Thread Stats Avg Stdev Max +/- Stdev - Latency 358.01us 151.88us 2.57ms 72.96% - Req/Sec 3.88k 171.86 5.04k 79.90% - 155645 requests in 10.10s, 117.86MB read - Socket errors: connect 0, read 155644, write 0, timeout 0 -Requests/sec: 15410.34 -Transfer/sec: 11.67MB -wrk http://localhost:8080/ -t 4 -d 10s 1.51s user 11.11s system 124% cpu 10.109 total + Latency 317.11us 168.20us 4.56ms 78.71% + Req/Sec 4.06k 204.86 5.18k 74.19% + 162856 requests in 10.10s, 117.10MB read + Socket errors: connect 0, read 162854, write 0, timeout 0 +Requests/sec: 16125.32 +Transfer/sec: 11.60MB +wrk http://localhost:8080 -t 4 -d 10s 1.59s user 11.45s system 128% cpu 10.110 total ``` + python -m http.server 8080 diff --git a/config.yaml b/config.yaml index 42bf596..0c9c228 100644 --- a/config.yaml +++ b/config.yaml @@ -1,6 +1,6 @@ bind: addr: 0.0.0.0 - listen: 80 + listen: 8081 server: info: "Powered by Rust" diff --git a/src/config.rs b/src/config.rs index 8020a31..e17036e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,4 @@ +use async_rwlock::RwLock; use clap::{command, Parser}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; @@ -98,11 +99,12 @@ pub struct Args { lazy_static! { pub static ref CONFIG_PATH: Mutex = Mutex::new("".to_owned()); - pub static ref CONFIG: Config = init_config(); + pub static ref DEFAULT_CONFIG: Config = init_config(); + pub static ref CONFIG: RwLock = RwLock::new((*DEFAULT_CONFIG).clone()); pub static ref ARGS: Args = Args::parse(); } -fn init_config() -> Config { +pub fn init_config() -> Config { let config_path = CONFIG_PATH.lock().unwrap(); let default_config = Config::default(); let mut config = match fs::read_to_string(config_path.to_owned()) { diff --git a/src/init.rs b/src/init.rs new file mode 100644 index 0000000..71951bd --- /dev/null +++ b/src/init.rs @@ -0,0 +1,139 @@ +use crate::{ + config::{init_config, CONFIG}, + zest::T, +}; +use async_mutex::Mutex; +use lazy_static::lazy_static; +use log4rs::Handle; +use signal_hook::{consts::SIGUSR1, iterator::Signals}; +use std::{env::set_current_dir, error::Error}; + +#[cfg(feature = "log")] +use { + crate::config::Config, + log4rs::{ + append::{console::ConsoleAppender, file::FileAppender}, + config::{Appender, Logger, Root}, + encode::pattern::PatternEncoder, + }, + std::{ops::Deref, path::Path}, +}; + +#[cfg(feature = "log")] +const LOG_FORMAT: &str = "[{d(%Y-%m-%dT%H:%M:%SZ)} {h({l})} zest] {m}\n"; + +lazy_static! { + pub static ref LOGGER_HANDLE: Mutex> = Mutex::new(None); +} + +#[cfg(feature = "log")] +pub async fn build_logger_config(config: C) -> log4rs::Config +where + C: Deref, +{ + let mut builder = log4rs::Config::builder(); + + let stdout = ConsoleAppender::builder() + .encoder(Box::new(PatternEncoder::new(LOG_FORMAT))) + .target(log4rs::append::console::Target::Stdout) + .build(); + + let stderr = ConsoleAppender::builder() + .encoder(Box::new(PatternEncoder::new(LOG_FORMAT))) + .target(log4rs::append::console::Target::Stderr) + .build(); + + let logging = &config.logging.clone().unwrap_or_default(); + builder = if let Some(access_log) = &logging.access_log { + let access_log_path = Path::new(&access_log); + std::fs::File::create(access_log_path).unwrap(); + builder.appender( + Appender::builder().build( + "logfile_access", + Box::new( + FileAppender::builder() + .encoder(Box::new(PatternEncoder::new(LOG_FORMAT))) + .build(access_log_path) + .unwrap(), + ), + ), + ) + } else { + builder.appender(Appender::builder().build("logfile_access", Box::new(stdout))) + }; + + builder = if let Some(error_log) = &logging.error_log { + let error_log_path = Path::new(&error_log); + std::fs::File::create(error_log_path).unwrap(); + builder.appender( + Appender::builder().build( + "logfile_error", + Box::new( + FileAppender::builder() + .encoder(Box::new(PatternEncoder::new(LOG_FORMAT))) + .build(error_log_path) + .unwrap(), + ), + ), + ) + } else { + builder.appender(Appender::builder().build("logfile_error", Box::new(stderr))) + }; + + builder + .logger( + Logger::builder() + .appender("logfile_access") + .additive(false) + .build("access", log::LevelFilter::Info), + ) + .logger( + Logger::builder() + .appender("logfile_error") + .additive(false) + .build("error", log::LevelFilter::Error), + ) + .build(Root::builder().build(log::LevelFilter::Off)) + .unwrap() +} + +#[cfg(feature = "log")] +pub async fn init_logger(config: C) +where + C: Deref, +{ + let config = build_logger_config(config).await; + *LOGGER_HANDLE.lock().await = Some(log4rs::init_config(config).unwrap()) +} + +pub async fn init_signal() -> Result<(), Box> { + let mut signals = Signals::new([SIGUSR1])?; + + tokio::spawn(async move { + for sig in signals.forever() { + if sig == SIGUSR1 { + let config: crate::config::Config = init_config(); + + let mut _c = CONFIG.try_write().unwrap(); + *_c = config.clone(); + drop(_c); + + set_current_dir(config.clone().server.root).unwrap(); + + #[cfg(feature = "log")] + { + let mut _handle = LOGGER_HANDLE.lock().await; + if let Some(handle) = _handle.take() { + handle.set_config(build_logger_config(&config.clone()).await); + } + } + + let mut t = T.try_write().unwrap(); + *t = None; + drop(t); + } + } + }); + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index 40bef0e..28166de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ pub mod config; +pub mod init; pub mod route; +pub mod zest; #[cfg(test)] mod tests { diff --git a/src/main.rs b/src/main.rs index a463d08..89f2874 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,435 +1,6 @@ -use zest::{ - config::{ARGS, CONFIG, CONFIG_PATH}, - route::{location_index, mime_match, root_relative, status_page}, -}; - -use chrono::{DateTime, Utc}; -use mime::Mime; -use std::{ - collections::HashMap, env::set_current_dir, error::Error, io, ops::Deref, path::Path, sync::Arc, -}; - -#[cfg(feature = "log")] -use { - log::logger, - log4rs::{ - append::{console::ConsoleAppender, file::FileAppender}, - config::{Appender, Logger, Root}, - encode::pattern::PatternEncoder, - }, -}; - -#[cfg(target_os = "android")] -use std::os::android::fs::MetadataExt; - -#[cfg(target_os = "linux")] -use std::os::linux::fs::MetadataExt; - -#[cfg(feature = "lru_cache")] -use { - async_mutex::Mutex, // faster than tokio::sync::Mutex - lazy_static::lazy_static, - lru::LruCache, - std::num::NonZeroUsize, -}; - -use tokio::{ - fs::File, - io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, - net::TcpListener, - sync::Semaphore, -}; - -const DATE_FORMAT: &str = "%a, %d %b %Y %H:%M:%S GMT"; - -#[cfg(feature = "log")] -const LOG_FORMAT: &str = "[{d(%Y-%m-%dT%H:%M:%SZ)} {h({l})} zest] {m}\n"; - -#[cfg(feature = "lru_cache")] -lazy_static! { - static ref INDEX_CACHE: Mutex> = { - let cache = LruCache::new( - NonZeroUsize::new( - CONFIG - .server - .cache - .clone() - .unwrap_or_default() - .index_capacity - .unwrap_or_default(), - ) - .unwrap(), - ); - Mutex::new(cache) - }; - static ref FILE_CACHE: Mutex>> = { - let cache = LruCache::new( - NonZeroUsize::new( - CONFIG - .server - .cache - .clone() - .unwrap_or_default() - .file_capacity - .unwrap_or_default(), - ) - .unwrap(), - ); - Mutex::new(cache) - }; -} - -#[derive(Clone)] -struct Response<'a> { - version: &'a str, - status_code: i32, - _headers_buffer: HashMap<&'a str, String>, -} - -impl<'a> Response<'a> { - #[inline] - fn send_header(&mut self, k: &'a str, v: T) -> Option - where - T: ToString, - { - self._headers_buffer.insert(k, v.to_string()) - } - #[inline] - fn resp(&mut self) -> String { - let (version, status_code) = (self.version, self.status_code); - let mut resp = format!("HTTP/{} {}\r\n", version, self.status(status_code)); - for (key, value) in &self._headers_buffer { - resp.push_str(&format!("{}: {}\r\n", key, value)); - } - resp.push_str("\r\n"); - resp - } - #[inline] - fn status(&mut self, status_code: i32) -> String { - let status = match status_code { - 200 => "OK", - 301 => "Moved Permanently", - 400 => "Bad Request", - 404 => "Not Found", - 501 => "Not Implemented", - _ => "Internal Server Error", // 500 - }; - - format!("{} {}", status_code, status) - } -} - -async fn handle_connection(mut stream: S) -> io::Result<(i32, String)> -where - S: AsyncReadExt + AsyncWriteExt + Unpin, -{ - let config = CONFIG.deref(); - - let mut response: Response = Response { - version: "1.1", - status_code: 200, - _headers_buffer: HashMap::new(), - }; - - let server_info = format!( - "Zest/{} ({})", - env!("CARGO_PKG_VERSION"), - config.server.info - ); - response.send_header("Server", server_info.clone()); - - response.send_header("Date", Utc::now().format(DATE_FORMAT)); - - let buf_reader = BufReader::new(&mut stream); - let req = buf_reader.lines().next_line().await?.unwrap_or_default(); - - // GET /location HTTP/1.1 - let parts: Vec<&str> = req.split('/').collect(); - - let mut mime_type: Mime = mime::TEXT_HTML_UTF_8; - let mut buffer: Vec = Vec::new(); - - if parts.len() < 3 { - response.status_code = 400; - } else if parts.first().unwrap().trim() != "GET" { - response.status_code = 501; - } else if let Some(location) = &req.split_whitespace().nth(1) { - let location: String = urlencoding::decode(root_relative(location)) - .unwrap_or_default() - .into(); - - response.version = parts.last().unwrap(); - let mut path = config.server.root.join(location.split('?').next().unwrap()); - - path = match path.canonicalize() { - Ok(canonical_path) => canonical_path, - Err(_) => { - response.status_code = 404; - config - .server - .root - .join(Path::new( - &config - .server - .error_page - .clone() - .unwrap_or("404.html".into()), - )) - .to_path_buf() - .canonicalize() - .unwrap_or_default() - } - }; - if path.is_dir() { - #[allow(unused_assignments)] - let mut html: String = String::new(); - #[cfg(feature = "lru_cache")] - { - let mut cache = INDEX_CACHE.lock().await; - if let Some(ctx) = cache.get(&location) { - html.clone_from(ctx); - } else if let Ok(index) = location_index(path, &location).await { - cache - .push(location.clone(), index) - .to_owned() - .unwrap_or_default(); - html.clone_from(cache.get(&location).unwrap()); - } else { - response.status_code = 301; - } - } - #[cfg(not(feature = "lru_cache"))] - { - if let Ok(index) = location_index(path, &location).await { - html = index; - } else { - response.status_code = 301; - } - } - - buffer = html.into_bytes(); - } else { - // path.is_file() - match File::open(path.clone()).await { - Ok(f) => { - let mut file = f; - mime_type = mime_match(path.to_str().unwrap()); - - #[cfg(feature = "lru_cache")] - { - let mut cache = FILE_CACHE.lock().await; - if let Some(content) = cache.get(&location) { - buffer = content.to_vec(); - } else { - file.read_to_end(&mut buffer).await?; - cache - .push(location.clone(), buffer.clone()) - .to_owned() - .unwrap_or_default(); - } - } - - #[cfg(not(feature = "lru_cache"))] - file.read_to_end(&mut buffer).await?; - - response.send_header( - "Last-Modified", - DateTime::from_timestamp(file.metadata().await?.st_atime(), 0) - .unwrap() - .format(DATE_FORMAT), - ); - } - Err(_) => { - response.status_code = 500; - } - }; - } - } else { - response.status_code = 400; - } - - if response.status_code != 200 { - buffer = status_page(&response.status(response.status_code), server_info) - .await - .into() - } - response.send_header("Content-Length", buffer.len()); - response.send_header("Content-Type", mime_type); - stream.write_all(response.resp().as_bytes()).await?; - stream.write_all(&buffer).await?; - stream.flush().await?; - stream.shutdown().await?; - - Ok((response.status_code, req)) -} - -#[inline] -#[cfg(feature = "log")] -fn init_logger(config: zest::config::Config) { - let mut builder = log4rs::Config::builder(); - - let stdout = ConsoleAppender::builder() - .encoder(Box::new(PatternEncoder::new(LOG_FORMAT))) - .target(log4rs::append::console::Target::Stdout) - .build(); - - let stderr = ConsoleAppender::builder() - .encoder(Box::new(PatternEncoder::new(LOG_FORMAT))) - .target(log4rs::append::console::Target::Stderr) - .build(); - - let logging = &config.logging.unwrap_or_default(); - builder = if let Some(access_log) = &logging.access_log { - let access_log_path = Path::new(&access_log); - std::fs::File::create(access_log_path).unwrap(); - builder.appender( - Appender::builder().build( - "logfile_access", - Box::new( - FileAppender::builder() - .encoder(Box::new(PatternEncoder::new(LOG_FORMAT))) - .build(access_log_path) - .unwrap(), - ), - ), - ) - } else { - builder.appender(Appender::builder().build("logfile_access", Box::new(stdout))) - }; - - builder = if let Some(error_log) = &logging.error_log { - let error_log_path = Path::new(&error_log); - std::fs::File::create(error_log_path).unwrap(); - builder.appender( - Appender::builder().build( - "logfile_error", - Box::new( - FileAppender::builder() - .encoder(Box::new(PatternEncoder::new(LOG_FORMAT))) - .build(error_log_path) - .unwrap(), - ), - ), - ) - } else { - builder.appender(Appender::builder().build("logfile_error", Box::new(stderr))) - }; - - let config = builder - .logger( - Logger::builder() - .appender("logfile_access") - .additive(false) - .build("access", log::LevelFilter::Info), - ) - .logger( - Logger::builder() - .appender("logfile_error") - .additive(false) - .build("error", log::LevelFilter::Error), - ) - .build(Root::builder().build(log::LevelFilter::Off)) - .unwrap(); - - log4rs::init_config(config).unwrap(); -} +use zest::zest::zest_main; #[tokio::main] -async fn main() -> Result<(), Box> { - *CONFIG_PATH.lock()? = ARGS.config.clone().unwrap_or_default(); - let config = CONFIG.deref(); - - set_current_dir(config.clone().server.root)?; - - #[cfg(feature = "log")] - init_logger(config.clone()); - - let listener = - TcpListener::bind(format!("{}:{}", config.bind.addr, config.bind.listen)).await?; - - let mut _allowlist: Option> = config.clone().allowlist; - let mut _blocklist: Option> = config.clone().blocklist; - - let rate_limiter = Arc::new(if let Some(rate_limit) = &config.rate_limit { - Semaphore::new(rate_limit.max_requests) - } else { - Semaphore::new(Semaphore::MAX_PERMITS) - }); - - #[allow(unused_labels)] - 'handle: loop { - #[allow(unused_mut)] - let (mut stream, _addr) = listener.accept().await?; - - #[cfg(feature = "ip_limit")] - { - if let Some(ref allowlist) = _allowlist { - for item in allowlist { - if let Ok(cidr) = item.parse::() { - if !cidr.contains(&_addr.ip()) { - if allowlist.last() != Some(item) { - continue; - } else { - stream.shutdown().await?; - continue 'handle; - } - } - } - } - } - - if let Some(ref blocklist) = _blocklist { - for item in blocklist { - if let Ok(cidr) = item.parse::() { - if cidr.contains(&_addr.ip()) { - stream.shutdown().await?; - continue 'handle; - } - } - } - } - } - - let rate_limiter = Arc::clone(&rate_limiter); - tokio::spawn(async move { - if rate_limiter.clone().try_acquire_owned().is_ok() { - let (_status_code, _req) = handle_connection(stream).await.unwrap_or_default(); - - #[cfg(feature = "log")] - { - match _status_code { - 200 => { - logger().log( - &log::Record::builder() - .level(log::Level::Info) - .target("access") - .args(format_args!("\"{}\" {} - {}", _req, _status_code, _addr)) - .build(), - ); - } - 400.. => { - logger().log( - &log::Record::builder() - .level(log::Level::Error) - .target("error") - .args(format_args!("\"{}\" {} - {}", _req, _status_code, _addr)) - .build(), - ); - } - _ => { - logger().log( - &log::Record::builder() - .level(log::Level::Warn) - .target("access") - .args(format_args!("\"{}\" {} - {}", _req, _status_code, _addr)) - .build(), - ); - } - }; - } - } else { - let _ = stream.shutdown().await; - } - }); - } +async fn main() -> Result<(), Box> { + zest_main().await } diff --git a/src/route.rs b/src/route.rs index bca4b7c..a09077e 100644 --- a/src/route.rs +++ b/src/route.rs @@ -3,7 +3,6 @@ use serde_yml::from_value; use std::{ fmt::Write, io::{ErrorKind, Result}, - ops::Deref, path::{Path, PathBuf}, }; use tokio::fs::{self, read_dir, DirEntry}; @@ -15,7 +14,7 @@ pub fn root_relative(p: &str) -> &str { #[inline] pub async fn location_index(path: PathBuf, location: &str) -> Result { - let config = CONFIG.deref(); + let config = CONFIG.try_read().unwrap(); for (s, v) in &config.locations.clone().unwrap_or_default() { if root_relative(s) == location.trim_end_matches('/') { diff --git a/src/zest.rs b/src/zest.rs new file mode 100644 index 0000000..1fe75a6 --- /dev/null +++ b/src/zest.rs @@ -0,0 +1,384 @@ +use crate::{ + config::{Config, ARGS, CONFIG, CONFIG_PATH, DEFAULT_CONFIG}, + init::init_signal, + route::{location_index, mime_match, root_relative, status_page}, +}; + +use async_rwlock::RwLock; +use chrono::{DateTime, Utc}; +use lazy_static::lazy_static; +use mime::Mime; +use std::{ + collections::HashMap, env::set_current_dir, error::Error, io, ops::Deref, path::Path, sync::Arc, +}; + +#[cfg(feature = "log")] +use {crate::init::init_logger, log::logger}; + +#[cfg(target_os = "android")] +use std::os::android::fs::MetadataExt; + +#[cfg(target_os = "linux")] +use std::os::linux::fs::MetadataExt; + +#[cfg(feature = "lru_cache")] +use { + async_mutex::Mutex, // faster than tokio::sync::Mutex + lru::LruCache, + std::num::NonZeroUsize, +}; + +use tokio::{ + fs::File, + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + net::TcpListener, + sync::Semaphore, +}; + +const DATE_FORMAT: &str = "%a, %d %b %Y %H:%M:%S GMT"; + +lazy_static! { + pub static ref T: Arc>> = Arc::new(RwLock::new(None)); +} + +#[cfg(feature = "lru_cache")] +lazy_static! { + static ref INDEX_CACHE: Mutex> = { + let cache = LruCache::new( + NonZeroUsize::new( + DEFAULT_CONFIG + .server + .cache + .clone() + .unwrap_or_default() + .index_capacity + .unwrap_or_default(), + ) + .unwrap(), + ); + Mutex::new(cache) + }; + static ref FILE_CACHE: Mutex>> = { + let cache = LruCache::new( + NonZeroUsize::new( + DEFAULT_CONFIG + .server + .cache + .clone() + .unwrap_or_default() + .file_capacity + .unwrap_or_default(), + ) + .unwrap(), + ); + Mutex::new(cache) + }; +} + +#[derive(Clone)] +struct Response<'a> { + version: &'a str, + status_code: i32, + _headers_buffer: HashMap<&'a str, String>, +} + +impl<'a> Response<'a> { + #[inline] + fn send_header(&mut self, k: &'a str, v: T) -> Option + where + T: ToString, + { + self._headers_buffer.insert(k, v.to_string()) + } + #[inline] + fn resp(&mut self) -> String { + let (version, status_code) = (self.version, self.status_code); + let mut resp = format!("HTTP/{} {}\r\n", version, self.status(status_code)); + for (key, value) in &self._headers_buffer { + resp.push_str(&format!("{}: {}\r\n", key, value)); + } + resp.push_str("\r\n"); + resp + } + #[inline] + fn status(&mut self, status_code: i32) -> String { + let status = match status_code { + 200 => "OK", + 301 => "Moved Permanently", + 400 => "Bad Request", + 404 => "Not Found", + 501 => "Not Implemented", + _ => "Internal Server Error", // 500 + }; + + format!("{} {}", status_code, status) + } +} + +async fn handle_connection(mut stream: S) -> io::Result<(i32, String)> +where + S: AsyncReadExt + AsyncWriteExt + Unpin, +{ + let config = CONFIG.try_read().unwrap(); + + let mut response: Response = Response { + version: "1.1", + status_code: 200, + _headers_buffer: HashMap::new(), + }; + + let server_info = format!( + "Zest/{} ({})", + env!("CARGO_PKG_VERSION"), + config.server.info + ); + response.send_header("Server", server_info.clone()); + + response.send_header("Date", Utc::now().format(DATE_FORMAT)); + + let buf_reader = BufReader::new(&mut stream); + let req = buf_reader.lines().next_line().await?.unwrap_or_default(); + + // GET /location HTTP/1.1 + let parts: Vec<&str> = req.split('/').collect(); + + let mut mime_type: Mime = mime::TEXT_HTML_UTF_8; + let mut buffer: Vec = Vec::new(); + + if parts.len() < 3 { + response.status_code = 400; + } else if parts.first().unwrap().trim() != "GET" { + response.status_code = 501; + } else if let Some(location) = &req.split_whitespace().nth(1) { + let location: String = urlencoding::decode(root_relative(location)) + .unwrap_or_default() + .into(); + + response.version = parts.last().unwrap(); + let mut path = config.server.root.join(location.split('?').next().unwrap()); + + path = match path.canonicalize() { + Ok(canonical_path) => canonical_path, + Err(_) => { + response.status_code = 404; + config + .server + .root + .join(Path::new( + &config + .server + .error_page + .clone() + .unwrap_or("404.html".into()), + )) + .to_path_buf() + .canonicalize() + .unwrap_or_default() + } + }; + if path.is_dir() { + #[allow(unused_assignments)] + let mut html: String = String::new(); + #[cfg(feature = "lru_cache")] + { + let mut cache = INDEX_CACHE.lock().await; + if let Some(ctx) = cache.get(&location) { + html.clone_from(ctx); + } else if let Ok(index) = location_index(path, &location).await { + cache + .push(location.clone(), index) + .to_owned() + .unwrap_or_default(); + html.clone_from(cache.get(&location).unwrap()); + } else { + response.status_code = 301; + } + } + #[cfg(not(feature = "lru_cache"))] + { + if let Ok(index) = location_index(path, &location).await { + html = index; + } else { + response.status_code = 301; + } + } + + buffer = html.into_bytes(); + } else { + // path.is_file() + match File::open(path.clone()).await { + Ok(f) => { + let mut file = f; + mime_type = mime_match(path.to_str().unwrap()); + + #[cfg(feature = "lru_cache")] + { + let mut cache = FILE_CACHE.lock().await; + if let Some(content) = cache.get(&location) { + buffer = content.to_vec(); + } else { + file.read_to_end(&mut buffer).await?; + cache + .push(location.clone(), buffer.clone()) + .to_owned() + .unwrap_or_default(); + } + } + + #[cfg(not(feature = "lru_cache"))] + file.read_to_end(&mut buffer).await?; + + response.send_header( + "Last-Modified", + DateTime::from_timestamp(file.metadata().await?.st_atime(), 0) + .unwrap() + .format(DATE_FORMAT), + ); + } + Err(_) => { + response.status_code = 500; + } + }; + } + } else { + response.status_code = 400; + } + + if response.status_code != 200 { + buffer = status_page(&response.status(response.status_code), server_info) + .await + .into() + } + response.send_header("Content-Length", buffer.len()); + response.send_header("Content-Type", mime_type); + stream.write_all(response.resp().as_bytes()).await?; + stream.write_all(&buffer).await?; + stream.flush().await?; + stream.shutdown().await?; + + Ok((response.status_code, req)) +} + +pub async fn zest_listener(config: C) -> Result<(), Box> +where + C: Deref, +{ + let listener = + TcpListener::bind(format!("{}:{}", config.bind.addr, config.bind.listen)).await?; + + let mut _allowlist: Option> = config.clone().allowlist; + let mut _blocklist: Option> = config.clone().blocklist; + + let rate_limiter = Arc::new(if let Some(rate_limit) = &config.rate_limit { + Semaphore::new(rate_limit.max_requests) + } else { + Semaphore::new(Semaphore::MAX_PERMITS) + }); + + #[allow(unused_labels)] + 'handle: loop { + T.try_read().unwrap().unwrap(); + + #[allow(unused_mut)] + let (mut stream, _addr) = listener.accept().await?; + + #[cfg(feature = "ip_limit")] + { + if let Some(ref allowlist) = _allowlist { + for item in allowlist { + if let Ok(cidr) = item.parse::() { + if !cidr.contains(&_addr.ip()) { + if allowlist.last() != Some(item) { + continue; + } else { + stream.shutdown().await?; + continue 'handle; + } + } + } + } + } + + if let Some(ref blocklist) = _blocklist { + for item in blocklist { + if let Ok(cidr) = item.parse::() { + if cidr.contains(&_addr.ip()) { + stream.shutdown().await?; + continue 'handle; + } + } + } + } + } + + let rate_limiter = Arc::clone(&rate_limiter); + tokio::spawn(async move { + if rate_limiter.clone().try_acquire_owned().is_ok() { + let (_status_code, _req) = handle_connection(stream).await.unwrap_or_default(); + + #[cfg(feature = "log")] + { + match _status_code { + 200 => { + logger().log( + &log::Record::builder() + .level(log::Level::Info) + .target("access") + .args(format_args!("\"{}\" {} - {}", _req, _status_code, _addr)) + .build(), + ); + } + 400.. => { + logger().log( + &log::Record::builder() + .level(log::Level::Error) + .target("error") + .args(format_args!("\"{}\" {} - {}", _req, _status_code, _addr)) + .build(), + ); + } + _ => { + logger().log( + &log::Record::builder() + .level(log::Level::Warn) + .target("access") + .args(format_args!("\"{}\" {} - {}", _req, _status_code, _addr)) + .build(), + ); + } + }; + } + } else { + let _ = stream.shutdown().await; + } + }); + } +} + +pub async fn zest_main() -> Result<(), Box> { + *CONFIG_PATH.lock()? = ARGS.config.clone().unwrap_or_default(); + let config = DEFAULT_CONFIG.deref(); + + set_current_dir(config.clone().server.root)?; + + #[cfg(feature = "log")] + init_logger(&config.clone()).await; + + init_signal().await.unwrap(); + + loop { + let _config = CONFIG.try_read().unwrap(); + let config = _config.clone(); + drop(_config); + + let handle = tokio::spawn(async move { + zest_listener(&config.clone()).await.unwrap(); + }); + + let mut t = T.try_write().unwrap(); + *t = Some(0); + drop(t); + + let _ = handle.await; + } +}