diff --git a/core/http/Cargo.toml b/core/http/Cargo.toml index ff62a0ae87..34fc957817 100644 --- a/core/http/Cargo.toml +++ b/core/http/Cargo.toml @@ -36,6 +36,7 @@ memchr = "2" stable-pattern = "0.1" cookie = { version = "0.18", features = ["percent-encode"] } state = "0.6" +headers = "0.4.0" [dependencies.serde] version = "1.0" diff --git a/core/http/src/header/header.rs b/core/http/src/header/header.rs index 11fa6f94b4..af03a9c680 100644 --- a/core/http/src/header/header.rs +++ b/core/http/src/header/header.rs @@ -1,6 +1,8 @@ +use core::str; use std::borrow::{Borrow, Cow}; use std::fmt; +use headers::{Header as HHeader, HeaderValue}; use indexmap::IndexMap; use crate::uncased::{Uncased, UncasedStr}; @@ -798,10 +800,153 @@ impl From<&cookie::Cookie<'_>> for Header<'static> { } } +/// A destination for `HeaderValue`s that can be used to accumulate +/// a single header value using from hyperium headers' decode protocol. +#[derive(Default)] +struct HeaderValueDestination { + value: Option, + count: usize, +} + +impl <'r>HeaderValueDestination { + fn into_value(self) -> HeaderValue { + if let Some(value) = self.value { + // TODO: if value.count > 1, then log that multiple header values are + // generated by the typed header, but that the dropped. + value + } else { + // Perhaps log that the typed header didn't create any values. + // This won't happen in the current implementation (headers 0.4.0). + HeaderValue::from_static("") + } + } + + fn into_header_string(self) -> Cow<'static, str> { + let value = self.into_value(); + // TODO: Optimize if we know this is a static reference. + value.to_str().unwrap_or("").to_string().into() + } +} + +impl Extend for HeaderValueDestination { + fn extend>(&mut self, iter: T) { + for value in iter { + self.count += 1; + if self.value.is_none() { + self.value = Some(value) + } + } + } +} + +macro_rules! import_typed_headers { +($($name:ident),*) => ($( + pub use headers::$name; + + impl ::std::convert::From for Header<'static> { + fn from(header: self::$name) -> Self { + let mut destination = HeaderValueDestination::default(); + header.encode(&mut destination); + let name = self::$name::name(); + Header::new(name.as_str(), destination.into_header_string()) + } + } +)*) +} + +macro_rules! import_generic_typed_headers { +($($name:ident<$bound:ident>),*) => ($( + pub use headers::$name; + + impl ::std::convert::From> + for Header<'static> { + fn from(header: self::$name) -> Self { + let mut destination = HeaderValueDestination::default(); + header.encode(&mut destination); + let name = self::$name::::name(); + Header::new(name.as_str(), destination.into_header_string()) + } + } +)*) +} + +// The following headers from 'headers' 0.4 are not imported, since they are +// provided by other Rocket features. + +// * ContentType, // Content-Type header, defined in RFC7231 +// * Cookie, // Cookie header, defined in RFC6265 +// * Host, // The Host header. +// * Location, // Location header, defined in RFC7231 +// * SetCookie, // Set-Cookie header, defined RFC6265 + +import_typed_headers! { + AcceptRanges, // Accept-Ranges header, defined in RFC7233 + AccessControlAllowCredentials, // Access-Control-Allow-Credentials header, part of CORS + AccessControlAllowHeaders, // Access-Control-Allow-Headers header, part of CORS + AccessControlAllowMethods, // Access-Control-Allow-Methods header, part of CORS + AccessControlAllowOrigin, // The Access-Control-Allow-Origin response header, part of CORS + AccessControlExposeHeaders, // Access-Control-Expose-Headers header, part of CORS + AccessControlMaxAge, // Access-Control-Max-Age header, part of CORS + AccessControlRequestHeaders, // Access-Control-Request-Headers header, part of CORS + AccessControlRequestMethod, // Access-Control-Request-Method header, part of CORS + Age, // Age header, defined in RFC7234 + Allow, // Allow header, defined in RFC7231 + CacheControl, // Cache-Control header, defined in RFC7234 with extensions in RFC8246 + Connection, // Connection header, defined in RFC7230 + ContentDisposition, // A Content-Disposition header, (re)defined in RFC6266. + ContentEncoding, // Content-Encoding header, defined in RFC7231 + ContentLength, // Content-Length header, defined in RFC7230 + ContentLocation, // Content-Location header, defined in RFC7231 + ContentRange, // Content-Range, described in RFC7233 + Date, // Date header, defined in RFC7231 + ETag, // ETag header, defined in RFC7232 + Expect, // The Expect header. + Expires, // Expires header, defined in RFC7234 + IfMatch, // If-Match header, defined in RFC7232 + IfModifiedSince, // If-Modified-Since header, defined in RFC7232 + IfNoneMatch, // If-None-Match header, defined in RFC7232 + IfRange, // If-Range header, defined in RFC7233 + IfUnmodifiedSince, // If-Unmodified-Since header, defined in RFC7232 + LastModified, // Last-Modified header, defined in RFC7232 + Origin, // The Origin header. + Pragma, // The Pragma header defined by HTTP/1.0. + Range, // Range header, defined in RFC7233 + Referer, // Referer header, defined in RFC7231 + ReferrerPolicy, // Referrer-Policy header, part of Referrer Policy + RetryAfter, // The Retry-After header. + SecWebsocketAccept, // The Sec-Websocket-Accept header. + SecWebsocketKey, // The Sec-Websocket-Key header. + SecWebsocketVersion, // The Sec-Websocket-Version header. + Server, // Server header, defined in RFC7231 + StrictTransportSecurity, // StrictTransportSecurity header, defined in RFC6797 + Te, // TE header, defined in RFC7230 + TransferEncoding, // Transfer-Encoding header, defined in RFC7230 + Upgrade, // Upgrade header, defined in RFC7230 + UserAgent, // User-Agent header, defined in RFC7231 + Vary // Vary header, defined in RFC7231 +} + +import_generic_typed_headers! { + Authorization, // Authorization header, defined in RFC7235 + ProxyAuthorization // Proxy-Authorization header, defined in RFC7235 +} + +pub use headers::authorization::Credentials; + #[cfg(test)] mod tests { + use std::time::SystemTime; + use super::HeaderMap; + #[test] + fn add_typed_header() { + use super::LastModified; + let mut map = HeaderMap::new(); + map.add(LastModified::from(SystemTime::now())); + assert!(map.get_one("last-modified").unwrap().contains("GMT")); + } + #[test] fn case_insensitive_add_get() { let mut map = HeaderMap::new(); diff --git a/core/http/src/header/mod.rs b/core/http/src/header/mod.rs index 653b786348..a92406b730 100644 --- a/core/http/src/header/mod.rs +++ b/core/http/src/header/mod.rs @@ -9,7 +9,18 @@ mod proxy_proto; pub use self::content_type::ContentType; pub use self::accept::{Accept, QMediaType}; pub use self::media_type::MediaType; -pub use self::header::{Header, HeaderMap}; +pub use self::header::{ + Header, HeaderMap, AcceptRanges, AccessControlAllowCredentials, + AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlAllowOrigin, + AccessControlExposeHeaders, AccessControlMaxAge, AccessControlRequestHeaders, + AccessControlRequestMethod, Age, Allow, CacheControl, Connection, ContentDisposition, + ContentEncoding, ContentLength, ContentLocation, ContentRange, Date, ETag, Expect, + Expires, IfMatch, IfModifiedSince, IfNoneMatch, IfRange, IfUnmodifiedSince, + LastModified, Origin, Pragma, Range, Referer, ReferrerPolicy, RetryAfter, + SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion, Server, StrictTransportSecurity, + Te, TransferEncoding, Upgrade, UserAgent, Vary, Authorization, ProxyAuthorization, + Credentials +}; pub use self::proxy_proto::ProxyProto; pub(crate) use self::media_type::Source; diff --git a/core/lib/tests/typed-headers.rs b/core/lib/tests/typed-headers.rs new file mode 100644 index 0000000000..e304e021a4 --- /dev/null +++ b/core/lib/tests/typed-headers.rs @@ -0,0 +1,30 @@ +#[macro_use] +extern crate rocket; + +use std::time::{Duration, SystemTime}; +use rocket::http::Expires; + +#[derive(Responder)] +struct MyResponse { + body: String, + expires: Expires, +} + +#[get("/")] +fn index() -> MyResponse { + let some_future_time = + SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(60 * 60 * 24 * 365 * 100)).unwrap(); + + MyResponse { + body: "Hello, world!".into(), + expires: Expires::from(some_future_time) + } +} + +#[test] +fn typed_header() { + let rocket = rocket::build().mount("/", routes![index]); + let client = rocket::local::blocking::Client::debug(rocket).unwrap(); + let response = client.get("/").dispatch(); + assert_eq!(response.headers().get_one("Expires").unwrap(), "Sat, 07 Dec 2069 00:00:00 GMT"); +}