diff --git a/Cargo.toml b/Cargo.toml index 42c5bc3..14f9a92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,14 +17,15 @@ repository = "https://github.com/bnjjj/tide-validator" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tide = "0.8.0" -futures = "0.3.4" -serde = { version = "1.0.106", features = ["derive"] } -async-std = { version = "1.5.0", features = ["attributes"] } +tide = "0.16.0" +async-std = { version = "1.8.0", features = ["attributes"] } +serde = { version = "1.0", features = ["derive"] } +dotenv = "0.15.0" +futures = "0.3" +serde_json = "1.0.64" [dev-dependencies] -async-std = "1.5.0" -http = "0.2.1" +http = "0.2.4" http-service-mock = "0.5.0" http-service = "0.5.0" -serde_json = "1.0.52" + diff --git a/src/lib.rs b/src/lib.rs index d0464a0..73c1220 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -146,13 +146,13 @@ //! //! For more details about examples check out [the `examples` directory on GitHub](https://github.com/bnjjj/tide-validator/tree/master/examples) +use serde::Serialize; +use serde_json::json; use std::collections::HashMap; use std::str::FromStr; use std::{fmt::Debug, sync::Arc}; +use tide::{http::headers::HeaderName, Body, Middleware, Next, Request, Response, StatusCode}; -use futures::future::BoxFuture; -use serde::Serialize; -use tide::{http::headers::HeaderName, Middleware, Next, Request, Response, StatusCode}; // trait Validator = Fn(&str) -> Result<(), String> + Send + Sync + 'static; /// Enum to indicate on which HTTP field you want to make validations @@ -191,27 +191,6 @@ impl ValidatorMiddleware where T: Serialize + Send + Sync + 'static, { - /// Create a new ValidatorMiddleware to put in your tide configuration. - /// - /// # Example - /// - /// ```rust,no_run,compile_fail - /// fn main() -> io::Result<()> { - /// task::block_on(async { - /// let mut app = tide::new(); - /// - /// let mut validator_middleware = ValidatorMiddleware::new(); - /// validator_middleware.add_validator(HttpField::Header("X-Custom-Header"), is_number); - /// - /// app.at("/test/:n").middleware(validator_middleware).get( - /// |_: tide::Request<()>| async move { Ok(tide::Response::new(StatusCode::Ok).body_json("test").unwrap()) }, - /// ); - /// - /// app.listen("127.0.0.1:8080").await?; - /// Ok(()) - /// }) - /// } - /// ``` pub fn new() -> Self { ValidatorMiddleware { validators: HashMap::new(), @@ -227,29 +206,6 @@ where } self } - - /// Add new validator for your middleware - /// - /// # Example - /// - /// ```rust,no_run,compile_fail - /// fn main() -> io::Result<()> { - /// task::block_on(async { - /// let mut app = tide::new(); - /// - /// let mut validator_middleware = ValidatorMiddleware::new(); - /// validator_middleware.add_validator(HttpField::Header("X-Custom-Header"), is_number); - /// validator_middleware.add_validator(HttpField::QueryParam("myqueryparam"), is_required); - /// - /// app.at("/test/:n").middleware(validator_middleware).get( - /// |_: tide::Request<()>| async move { Ok(tide::Response::new(StatusCode::Ok).body_json("test").unwrap()) }, - /// ); - /// - /// app.listen("127.0.0.1:8080").await?; - /// Ok(()) - /// }) - /// } - /// ``` pub fn add_validator(&mut self, param_name: HttpField<'static>, validator: F) where F: Fn(&str, Option<&str>) -> Result<(), T> + Send + Sync + 'static, @@ -263,309 +219,81 @@ where } } +#[tide::utils::async_trait] impl Middleware for ValidatorMiddleware where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, T: Serialize + Send + Sync + 'static, { - fn handle<'a>( - &'a self, - ctx: Request, - next: Next<'a, State>, - ) -> BoxFuture<'a, tide::Result> { - Box::pin(async move { - let mut query_parameters: Option> = None; - - for (param_name, validators) in &self.validators { - match param_name { - HttpField::Param(param_name) => { - for validator in validators { - let param_found: Result = ctx.param(param_name); - if let Err(err) = - validator(param_name, param_found.ok().as_ref().map(|p| &p[..])) - { - return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else( - |err| { - Response::new(StatusCode::InternalServerError).body_string(format!( - "cannot serialize your parameter validator for '{}' error : {:?}", - param_name, - err - )) - }, - )); + async fn handle(&self, ctx: Request, next: Next<'_, State>) -> tide::Result { + let mut query_parameters: Option> = None; + for (param_name, validators) in &self.validators { + match param_name { + HttpField::Param(param_name) => { + for validator in validators { + // let param_found = ctx.param(param_name).unwrap(); + // // let opt: Option = Some(param_found.to_owned()); + // // let value = opt.as_ref().map(|x| &**x).unwrap_or(""); + // if let Err(err) = validator(param_name, Some(param_found)) { + // let mut response = Response::new(StatusCode::BadRequest); + // let body_json = Body::from_json(&json!(&err))?; + // response.set_body(body_json); + // return Ok(response); + // } + + match ctx.param(param_name) { + Err(_err) => { + return Ok(Response::new(StatusCode::BadRequest)); } - } - } - HttpField::QueryParam(param_name) => { - if query_parameters.is_none() { - match ctx.query::>() { - Err(err) => { - return Ok(Response::new(StatusCode::InternalServerError) - .body_string(format!( - "cannot read query parameters: {:?}", - err - ))); + Ok(param_found) => { + if let Err(_err) = validator(param_name, Some(param_found)) { + return Ok(Response::new(StatusCode::InternalServerError)); } - Ok(qps) => query_parameters = Some(qps), - } - } - let query_parameters = query_parameters.as_ref().unwrap(); - - for validator in validators { - if let Err(err) = validator( - param_name, - query_parameters.get(¶m_name[..]).map(|p| &p[..]), - ) { - return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else( - |err| { - Response::new(StatusCode::InternalServerError).body_string(format!( - "cannot serialize your query parameter validator for '{}' error : {:?}", - param_name, - err - )) - }, - )); } } } - HttpField::Header(header_name) => { - for validator in validators { - let header_found: Option<&str> = ctx - .header(&HeaderName::from_str(header_name).unwrap()) - .map(|header| header.last().map(|val| val.as_str()).unwrap()); - if let Err(err) = validator(header_name, header_found) { - return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else( - |err| { - Response::new(StatusCode::InternalServerError).body_string(format!( - "cannot serialize your header validator for '{}' error : {:?}", - header_name, - err - )) - }, - )); - } + } + HttpField::QueryParam(param_name) => { + if query_parameters.is_none() { + match ctx.query::>() { + Err(_err) => return Ok(Response::new(StatusCode::InternalServerError)), + Ok(qps) => query_parameters = Some(qps), } } - HttpField::Cookie(cookie_name) => { - for validator in validators { - let cookie_found = ctx.cookie(cookie_name); - if let Err(err) = - validator(cookie_name, cookie_found.as_ref().map(|c| c.value())) - { - return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else( - |err| { - Response::new(StatusCode::InternalServerError).body_string(format!( - "cannot serialize your cookie validator for '{}' error : {:?}", - cookie_name, - err - )) - }, - )); - } + let query_parameters = query_parameters.as_ref().unwrap(); + + for validator in validators { + if let Err(_err) = validator( + param_name, + query_parameters.get(¶m_name[..]).map(|p| &p[..]), + ) { + return Ok(Response::new(StatusCode::InternalServerError)); } } } - } - next.run(ctx).await - }) - } -} - -#[cfg(test)] -mod tests { - - use super::{HttpField, StatusCode, ValidatorMiddleware}; - - use super::*; - use async_std::io::prelude::*; - use futures::executor::block_on; - use http_service_mock::make_server; - use serde::{Deserialize, Serialize}; - use tide::http::{Method, Request}; - - #[inline] - fn is_number(field_name: &str, field_value: Option<&str>) -> Result<(), String> { - if let Some(field_value) = field_value { - if field_value.parse::().is_err() { - return Err(format!( - "field '{}' = '{}' is not a valid number", - field_name, field_value - )); - } - } - - Ok(()) - } - - #[test] - fn validator_simple() { - let mut inner = tide::new(); - let mut validators = ValidatorMiddleware::new(); - validators.add_validator(HttpField::Param("bar"), is_number); - inner - .at("/foo/:bar") - .middleware(validators) - .get(|_| async { Ok("foo") }); - - let mut server = make_server(inner).unwrap(); - - let mut buf = Vec::new(); - let req = Request::new(Method::Get, "http://localhost/foo/4".parse().unwrap()); - let mut res = server.simulate(req).unwrap(); - assert_eq!(res.status(), 200); - block_on(res.read_to_end(&mut buf)).unwrap(); - assert_eq!(&*buf, &*b"foo"); - - buf.clear(); - let req = Request::new(Method::Get, "http://localhost/foo/bar".parse().unwrap()); - let mut res = server.simulate(req).unwrap(); - assert_eq!(res.status(), StatusCode::BadRequest); - block_on(res.read_to_end(&mut buf)).unwrap(); - assert_eq!( - String::from_utf8_lossy(&buf[..]), - String::from(r#""field 'bar' = 'bar' is not a valid number""#) - ); - } - - #[derive(Debug, Serialize, Deserialize)] - struct CustomError { - status_code: usize, - message: String, - } - - fn is_length_under( - max_length: usize, - ) -> Box) -> Result<(), CustomError> + Send + Sync + 'static> { - Box::new( - move |field_name: &str, field_value: Option<&str>| -> Result<(), CustomError> { - if let Some(field_value) = field_value { - if field_value.len() > max_length { - let my_error = CustomError { - status_code: 400, - message: format!( - "element '{}' which is equals to '{}' have not the maximum length of {}", - field_name, field_value, max_length - ), - }; - return Err(my_error); + HttpField::Header(header_name) => { + for validator in validators { + let x_header = &HeaderName::from_str(header_name).unwrap(); + let header_found = ctx.header(x_header.as_str()); + let c = header_found.map(|h| h.last().as_str()); + + if let Err(_err) = validator(header_name, c) { + return Ok(Response::new(StatusCode::InternalServerError)); + } } } - Ok(()) - }, - ) - } - - #[test] - fn validator_custom() { - let mut inner = tide::new(); - let mut validators = ValidatorMiddleware::new(); - validators.add_validator(HttpField::QueryParam("test"), is_length_under(10)); - validators.add_validator(HttpField::Cookie("session"), is_length_under(10)); - inner - .at("/foo") - .middleware(validators) - .get(|_| async { Ok("foo") }); - - let mut server = make_server(inner).unwrap(); - - let mut buf = Vec::new(); - let req = Request::new( - Method::Get, - "http://localhost/foo?test=coucou".parse().unwrap(), - ); - let mut res = server.simulate(req).unwrap(); - assert_eq!(res.status(), 200); - block_on(res.read_to_end(&mut buf)).unwrap(); - assert_eq!(&*buf, &*b"foo"); - - buf.clear(); - - let req = Request::new( - Method::Get, - "http://localhost/foo?test=blablablablabla".parse().unwrap(), - ); - let mut res = server.simulate(req).unwrap(); - assert_eq!(res.status(), StatusCode::BadRequest); - block_on(res.read_to_end(&mut buf)).unwrap(); - - let err: CustomError = serde_json::from_slice(&buf[..]).unwrap(); - - assert_eq!(err.status_code, 400usize); - assert_eq!( - err.message, - String::from("element 'test' which is equals to 'blablablablabla' have not the maximum length of 10") - ); - } - - #[inline] - fn is_bool(field_name: &str, field_value: Option<&str>) -> Result<(), CustomError> { - if let Some(field_value) = field_value { - match field_value { - "true" | "false" => return Ok(()), - other => { - return Err(CustomError { - status_code: 400, - message: format!( - "field '{}' = '{}' is not a valid boolean", - field_name, other - ), - }) + HttpField::Cookie(cookie_name) => { + for validator in validators { + let cookie_found = ctx.cookie(cookie_name); + if let Err(_err) = + validator(cookie_name, cookie_found.as_ref().map(|c| c.value())) + { + return Ok(Response::new(StatusCode::InternalServerError)); + } + } } } } - Ok(()) - } - - #[inline] - fn is_required(field_name: &str, field_value: Option<&str>) -> Result<(), CustomError> { - if field_value.is_none() { - Err(CustomError { - status_code: 400, - message: format!("'{}' is mandatory", field_name), - }) - } else { - Ok(()) - } - } - - #[test] - fn validator_chains() { - let mut inner = tide::new(); - let mut validators = ValidatorMiddleware::new(); - validators.add_validator(HttpField::QueryParam("test"), is_length_under(10)); - validators.add_validator(HttpField::Header("X-Is-Connected"), is_required); - validators.add_validator(HttpField::Header("X-Is-Connected"), is_bool); - inner - .at("/foo") - .middleware(validators) - .get(|_| async { Ok("foo") }); - - let mut server = make_server(inner).unwrap(); - - let mut buf = Vec::new(); - - let mut req = Request::new( - Method::Get, - "http://localhost/foo?test=coucou".parse().unwrap(), - ); - req.insert_header("X-Is-Connected", "true").unwrap(); - let mut res = server.simulate(req).unwrap(); - assert_eq!(res.status(), 200); - block_on(res.read_to_end(&mut buf)).unwrap(); - assert_eq!(&*buf, &*b"foo"); - - buf.clear(); - let req = Request::new( - Method::Get, - "http://localhost/foo?test=coucou".parse().unwrap(), - ); - let mut res = server.simulate(req).unwrap(); - assert_eq!(res.status(), StatusCode::BadRequest); - block_on(res.read_to_end(&mut buf)).unwrap(); - - let err: CustomError = serde_json::from_slice(&buf[..]).unwrap(); - - assert_eq!(err.status_code, 400usize); - assert_eq!(err.message, String::from("'X-Is-Connected' is mandatory")); + Ok(next.run(ctx).await) } }