Skip to content

Commit

Permalink
Implement FromRequest for typed headers rwf2#1283
Browse files Browse the repository at this point in the history
  • Loading branch information
jespersm committed Feb 11, 2021
1 parent 5f834c1 commit 3d5db5e
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 1 deletion.
3 changes: 2 additions & 1 deletion core/http/src/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
pub mod header {
use super::super::header::Header;
pub use hyperx::header::Header as HyperxHeaderTrait;

pub use hyperx::header::Raw;

macro_rules! import_http_headers {
($($name:ident),*) => ($(
pub use http::header::$name as $name;
Expand Down
91 changes: 91 additions & 0 deletions core/lib/src/request/from_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::fmt::Debug;
use std::net::{IpAddr, SocketAddr};

use futures::future::BoxFuture;
use crate::http::hyper::header::HyperxHeaderTrait;
use crate::http::hyper::header::Raw;

use crate::router::Route;
use crate::request::Request;
Expand Down Expand Up @@ -459,6 +461,95 @@ impl<'a, 'r> FromRequest<'a, 'r> for IpAddr {
}
}


macro_rules! from_typed_headers {
($($name:ident),*) => ($(
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for $name {

type Error = std::convert::Infallible;

async fn from_request(request: &'a Request<'r>) ->
Outcome<Self, Self::Error> {
let headers = request.headers().get($name::header_name());
let vv : Vec<Vec<u8>> = headers.map(
|s| Vec::from(s.as_bytes())).collect();
if vv.is_empty() {
Forward(())
} else {
match $name::parse_header(& Raw::from(vv)) {
Ok(v) => Success(v),
Err(_) => Forward(())
}
}
}
}
)*)
}

macro_rules! from_generic_typed_headers {
($($name:ident<$bound:ident>),*) => ($(

#[crate::async_trait]
impl<'a, 'r, T: 'static + $bound> FromRequest<'a, 'r> for $name<T> {

type Error = std::convert::Infallible;

async fn from_request(request: &'a Request<'r>) ->
Outcome<Self, Self::Error> {
let headers = request.headers().get($name::<T>::header_name());
let vv : Vec<Vec<u8>> = headers.map(
|s| Vec::from(s.as_bytes())).collect();
if vv.is_empty() {
Forward(())
} else {
match $name::parse_header(& Raw::from(vv)) {
Ok(v) => Success(v),
Err(_) => Forward(())
}
}
}
}
)*)
}

use crate::http::hyper::header::{
Accept as AcceptHX, AcceptCharset, AcceptEncoding, AcceptLanguage,
AcceptRanges, AccessControlAllowCredentials, AccessControlAllowHeaders,
AccessControlAllowMethods, AccessControlAllowOrigin,
AccessControlExposeHeaders, AccessControlMaxAge,
AccessControlRequestHeaders, AccessControlRequestMethod, Allow,
CacheControl, Connection, ContentDisposition, ContentEncoding,
ContentLanguage, ContentLength, ContentLocation, ContentRange,
ContentType as ContentTypeHX, Cookie as CookieHX, Date, ETag, Expires,
Expect, From, Host, IfMatch, IfModifiedSince, IfNoneMatch,
IfUnmodifiedSince, IfRange, LastEventId, LastModified, Link, Location,
Origin as OriginHX, Pragma, Prefer, PreferenceApplied,
Range, Referer, ReferrerPolicy, RetryAfter, Server, SetCookie,
StrictTransportSecurity, Te, TransferEncoding, Upgrade, UserAgent,
Vary, Warning, Authorization, ProxyAuthorization, Scheme};

from_typed_headers! {
AcceptHX, AcceptCharset, AcceptEncoding, AcceptLanguage, AcceptRanges,
AccessControlAllowCredentials, AccessControlAllowHeaders,
AccessControlAllowMethods, AccessControlAllowOrigin,
AccessControlExposeHeaders, AccessControlMaxAge,
AccessControlRequestHeaders, AccessControlRequestMethod, Allow,
CacheControl, Connection, ContentDisposition, ContentEncoding,
ContentLanguage, ContentLength, ContentLocation, ContentRange,
ContentTypeHX, CookieHX, Date, ETag, Expires, Expect, From, Host, IfMatch,
IfModifiedSince, IfNoneMatch, IfUnmodifiedSince, IfRange, LastEventId,
LastModified, Link, Location, OriginHX, Pragma, Prefer, PreferenceApplied,
Range, Referer, ReferrerPolicy, RetryAfter, Server, SetCookie,
StrictTransportSecurity, Te, TransferEncoding, Upgrade, UserAgent, Vary,
Warning
}
from_generic_typed_headers! {
Authorization<Scheme>,
ProxyAuthorization<Scheme>
}


#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for SocketAddr {
type Error = std::convert::Infallible;
Expand Down
74 changes: 74 additions & 0 deletions core/lib/tests/typed-header-param-1283.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use rocket;

use rocket::{get, routes};
use rocket::http::Status;
use rocket::local::blocking::Client;
use rocket::http::hyper::header::{AcceptLanguage, qitem, Authorization, Bearer};
use language_tags::langtag;

#[get("/lang-required")]
fn lang_required(lang: AcceptLanguage) -> String {
format!("Accept-Language: {}", lang)
}

#[get("/lang-required", rank = 2)]
fn no_lang() -> Status {
Status::BadRequest
}

#[test]
fn test_required_header() {
let rocket = rocket::ignite().mount("/", routes![lang_required, no_lang]);
let client = Client::tracked(rocket).unwrap();

// Will hit no_lang() above
let response = client.get("/lang-required").dispatch();
assert_eq!(response.status(), Status::BadRequest);

// Will hit the lang_required() route
let request = client.get("/lang-required");
let response = request.header(AcceptLanguage(vec![qitem(langtag!(da))])).dispatch();
assert_eq!(response.into_string().unwrap(), "Accept-Language: da");
}

#[get("/lang-optional")]
fn lang_optional(lang: Option<AcceptLanguage>) -> String {
if let Some(lang) = lang {
format!("Accept-Language: {}", lang)
} else {
format!("English is the lingua franca of the internet")
}
}

#[test]
fn test_optional_header() {
let rocket = rocket::ignite().mount("/", routes![lang_optional]);
let client = Client::tracked(rocket).unwrap();

// When header is present
let request = client.get("/lang-optional");
let response = request.header(AcceptLanguage(vec![qitem(langtag!(da))])).dispatch();
assert_eq!(response.into_string().unwrap(), "Accept-Language: da");

// When header is elided
let response = client.get("/lang-optional").dispatch();
assert_eq!(response.into_string().unwrap(), "English is the lingua franca of the internet");
}

#[get("/spill-beans")]
fn spill_beans(auth: Authorization<Bearer>) -> String {
format!("I'll tell you secrets: {:?}", auth.0.token)
}

#[test]
fn test_generic_header() {
let rocket = rocket::ignite().mount("/", routes![spill_beans]);
let client = Client::tracked(rocket).unwrap();

let request = client.get("/spill-beans");
let response = request.header(Authorization(Bearer {
token: "aaa".to_owned()
})).dispatch();
assert_eq!(response.into_string().unwrap(), "I'll tell you secrets: \"aaa\"");
}

0 comments on commit 3d5db5e

Please sign in to comment.