Skip to content

Commit

Permalink
Initial brush
Browse files Browse the repository at this point in the history
  • Loading branch information
the10thWiz committed Jun 20, 2024
1 parent 6857b82 commit 56e7fa6
Show file tree
Hide file tree
Showing 15 changed files with 212 additions and 46 deletions.
16 changes: 14 additions & 2 deletions core/codegen/src/attribute/catch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,25 @@ pub fn _catch(
.map(|ty| ty.span())
.unwrap_or_else(Span::call_site);

// TODO: how to handle request?
// - Right now: (), (&Req), (Status, &Req) allowed
// - New: (), (&E), (&Req, &E), (Status, &Req, &E)
// Set the `req` and `status` spans to that of their respective function
// arguments for a more correct `wrong type` error span. `rev` to be cute.
let codegen_args = &[__req, __status];
let codegen_args = &[__req, __status, __error];
let inputs = catch.function.sig.inputs.iter().rev()
.zip(codegen_args.iter())
.map(|(fn_arg, codegen_arg)| match fn_arg {
syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()),
syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span())
}).rev();
let make_error = if let Some(arg) = catch.function.sig.inputs.iter().rev().next() {
quote_spanned!(arg.span() =>
// let
)
} else {
quote! {}
};

// We append `.await` to the function call if this is `async`.
let dot_await = catch.function.sig.asyncness
Expand All @@ -68,9 +78,11 @@ pub fn _catch(
fn into_info(self) -> #_catcher::StaticInfo {
fn monomorphized_function<'__r>(
#__status: #Status,
#__req: &'__r #Request<'_>
#__req: &'__r #Request<'_>,
__error_init: &#ErasedErrorRef<'__r>,
) -> #_catcher::BoxFuture<'__r> {
#_Box::pin(async move {
#make_error
let __response = #catcher_response;
#Response::build()
.status(#__status)
Expand Down
16 changes: 10 additions & 6 deletions core/codegen/src/attribute/route/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fn query_decls(route: &Route) -> Option<TokenStream> {
fn request_guard_decl(guard: &Guard) -> TokenStream {
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
define_spanned_export!(ty.span() =>
__req, __data, _request, display_hack, FromRequest, Outcome
__req, __data, _request, display_hack, FromRequest, Outcome, ErrorResolver, ErrorDefault
);

quote_spanned! { ty.span() =>
Expand All @@ -150,11 +150,13 @@ fn request_guard_decl(guard: &Guard) -> TokenStream {
target: concat!("rocket::codegen::route::", module_path!()),
parameter = stringify!(#ident),
type_name = stringify!(#ty),
reason = %#display_hack!(__e),
reason = %#display_hack!(&__e),
"request guard failed"
);

return #Outcome::Error(__c);
#[allow(unused)]
use #ErrorDefault;
return #Outcome::Error((__c, #ErrorResolver::new(__e).cast()));
}
};
}
Expand Down Expand Up @@ -219,7 +221,7 @@ fn param_guard_decl(guard: &Guard) -> TokenStream {

fn data_guard_decl(guard: &Guard) -> TokenStream {
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
define_spanned_export!(ty.span() => __req, __data, display_hack, FromData, Outcome);
define_spanned_export!(ty.span() => __req, __data, display_hack, FromData, Outcome, ErrorResolver, ErrorDefault);

quote_spanned! { ty.span() =>
let #ident: #ty = match <#ty as #FromData>::from_data(#__req, #__data).await {
Expand All @@ -243,11 +245,13 @@ fn data_guard_decl(guard: &Guard) -> TokenStream {
target: concat!("rocket::codegen::route::", module_path!()),
parameter = stringify!(#ident),
type_name = stringify!(#ty),
reason = %#display_hack!(__e),
reason = %#display_hack!(&__e),
"data guard failed"
);

return #Outcome::Error(__c);
#[allow(unused)]
use #ErrorDefault;
return #Outcome::Error((__c, #ErrorResolver::new(__e).cast()));
}
};
}
Expand Down
3 changes: 3 additions & 0 deletions core/codegen/src/exports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ define_exported_paths! {
Route => ::rocket::Route,
Catcher => ::rocket::Catcher,
Status => ::rocket::http::Status,
ErrorResolver => ::rocket::catcher::resolution::Resolve,
ErrorDefault => ::rocket::catcher::resolution::DefaultTypeErase,
ErasedErrorRef => ::rocket::catcher::ErasedErrorRef,
}

macro_rules! define_spanned_export {
Expand Down
1 change: 1 addition & 0 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ tokio-stream = { version = "0.1.6", features = ["signal", "time"] }
cookie = { version = "0.18", features = ["percent-encode"] }
futures = { version = "0.3.30", default-features = false, features = ["std"] }
state = "0.6"
transient = { version = "0.2.0", path = "../../../transient" }

# tracing
tracing = { version = "0.1.40", default-features = false, features = ["std", "attributes"] }
Expand Down
28 changes: 15 additions & 13 deletions core/lib/src/catcher/catcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::request::Request;
use crate::http::{Status, ContentType, uri};
use crate::catcher::{Handler, BoxFuture};

use super::ErasedErrorRef;

/// An error catching route.
///
/// Catchers are routes that run when errors are produced by the application.
Expand Down Expand Up @@ -147,20 +149,20 @@ impl Catcher {
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, BoxFuture};
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
/// use rocket::response::Responder;
/// use rocket::http::Status;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// }
///
/// fn handle_500<'r>(_: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
/// fn handle_500<'r>(_: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// Box::pin(async move{ "Whoops, we messed up!".respond_to(req) })
/// }
///
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// let res = (status, format!("{}: {}", status, req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// }
Expand Down Expand Up @@ -199,11 +201,11 @@ impl Catcher {
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, BoxFuture};
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
/// use rocket::response::Responder;
/// use rocket::http::Status;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// }
Expand All @@ -225,12 +227,12 @@ impl Catcher {
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, BoxFuture};
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
/// use rocket::response::Responder;
/// use rocket::http::Status;
/// # use rocket::uri;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// }
Expand Down Expand Up @@ -279,11 +281,11 @@ impl Catcher {
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, BoxFuture};
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
/// use rocket::response::Responder;
/// use rocket::http::Status;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// }
Expand Down Expand Up @@ -313,7 +315,7 @@ impl Catcher {

impl Default for Catcher {
fn default() -> Self {
fn handler<'r>(s: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
fn handler<'r>(s: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
Box::pin(async move { Ok(default_handler(s, req)) })
}

Expand All @@ -331,7 +333,7 @@ pub struct StaticInfo {
/// The catcher's status code.
pub code: Option<u16>,
/// The catcher's handler, i.e, the annotated function.
pub handler: for<'r> fn(Status, &'r Request<'_>) -> BoxFuture<'r>,
pub handler: for<'r> fn(Status, &'r Request<'_>, &ErasedErrorRef<'r>) -> BoxFuture<'r>,
/// The file, line, and column where the catcher was defined.
pub location: (&'static str, u32, u32),
}
Expand Down Expand Up @@ -418,7 +420,7 @@ macro_rules! default_handler_fn {

pub(crate) fn default_handler<'r>(
status: Status,
req: &'r Request<'_>
req: &'r Request<'_>,
) -> Response<'r> {
let preferred = req.accept().map(|a| a.preferred());
let (mime, text) = if preferred.map_or(false, |a| a.is_json()) {
Expand Down
18 changes: 11 additions & 7 deletions core/lib/src/catcher/handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{Request, Response};
use crate::http::Status;

use super::ErasedErrorRef;

/// Type alias for the return type of a [`Catcher`](crate::Catcher)'s
/// [`Handler::handle()`].
pub type Result<'r> = std::result::Result<Response<'r>, crate::http::Status>;
Expand Down Expand Up @@ -29,7 +31,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>;
/// and used as follows:
///
/// ```rust,no_run
/// use rocket::{Request, Catcher, catcher};
/// use rocket::{Request, Catcher, catcher::{self, ErasedErrorRef}};
/// use rocket::response::{Response, Responder};
/// use rocket::http::Status;
///
Expand All @@ -45,7 +47,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>;
///
/// #[rocket::async_trait]
/// impl catcher::Handler for CustomHandler {
/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>) -> catcher::Result<'r> {
/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> catcher::Result<'r> {
/// let inner = match self.0 {
/// Kind::Simple => "simple".respond_to(req)?,
/// Kind::Intermediate => "intermediate".respond_to(req)?,
Expand Down Expand Up @@ -97,30 +99,32 @@ pub trait Handler: Cloneable + Send + Sync + 'static {
/// Nevertheless, failure is allowed, both for convenience and necessity. If
/// an error handler fails, Rocket's default `500` catcher is invoked. If it
/// succeeds, the returned `Response` is used to respond to the client.
async fn handle<'r>(&self, status: Status, req: &'r Request<'_>) -> Result<'r>;
async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, error: &ErasedErrorRef<'r>) -> Result<'r>;
}

// We write this manually to avoid double-boxing.
impl<F: Clone + Sync + Send + 'static> Handler for F
where for<'x> F: Fn(Status, &'x Request<'_>) -> BoxFuture<'x>,
where for<'x> F: Fn(Status, &'x Request<'_>, &ErasedErrorRef<'x>) -> BoxFuture<'x>,
{
fn handle<'r, 'life0, 'life1, 'async_trait>(
fn handle<'r, 'life0, 'life1, 'life2, 'async_trait>(
&'life0 self,
status: Status,
req: &'r Request<'life1>,
error: &'life2 ErasedErrorRef<'r>,
) -> BoxFuture<'r>
where 'r: 'async_trait,
'life0: 'async_trait,
'life1: 'async_trait,
'life2: 'async_trait,
Self: 'async_trait,
{
self(status, req)
self(status, req, error)
}
}

// Used in tests! Do not use, please.
#[doc(hidden)]
pub fn dummy_handler<'r>(_: Status, _: &'r Request<'_>) -> BoxFuture<'r> {
pub fn dummy_handler<'r>(_: Status, _: &'r Request<'_>, _: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
Box::pin(async move { Ok(Response::new()) })
}

Expand Down
2 changes: 2 additions & 0 deletions core/lib/src/catcher/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
mod catcher;
mod handler;
mod types;

pub use catcher::*;
pub use handler::*;
pub use types::*;
110 changes: 110 additions & 0 deletions core/lib/src/catcher/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
use transient::{Any, CanRecoverFrom, Co, Transient, Downcast};

pub type ErasedError<'r> = Box<dyn Any<Co<'r>> + Send + Sync + 'r>;
pub type ErasedErrorRef<'r> = dyn Any<Co<'r>> + Send + Sync + 'r;

pub fn default_error_type<'r>() -> ErasedError<'r> {
Box::new(())
}

pub fn downcast<'a, 'r, T: Transient + 'r>(v: &'a ErasedErrorRef<'r>) -> Option<&'a T>
where T::Transience: CanRecoverFrom<Co<'r>>
{
v.downcast_ref()
}

// /// Chosen not to expose this macro, since it's pretty short and sweet
// #[doc(hidden)]
// #[macro_export]
// macro_rules! resolve_typed_catcher {
// ($T:expr) => ({
// #[allow(unused_imports)]
// use $crate::catcher::types::Resolve;
//
// Resolve::new($T).cast()
// })
// }

// pub use resolve_typed_catcher;

pub mod resolution {
use std::marker::PhantomData;

use transient::{CanTranscendTo, Transient};

use super::*;

/// The *magic*.
///
/// `Resolve<T>::item` for `T: Transient` is `<T as Transient>::item`.
/// `Resolve<T>::item` for `T: !Transient` is `DefaultTypeErase::item`.
///
/// This _must_ be used as `Resolve::<T>:item` for resolution to work. This
/// is a fun, static dispatch hack for "specialization" that works because
/// Rust prefers inherent methods over blanket trait impl methods.
pub struct Resolve<'r, T: 'r>(T, PhantomData<&'r ()>);

impl<'r, T: 'r> Resolve<'r, T> {
pub fn new(val: T) -> Self {
Self(val, PhantomData)
}
}

/// Fallback trait "implementing" `Transient` for all types. This is what
/// Rust will resolve `Resolve<T>::item` to when `T: !Transient`.
pub trait DefaultTypeErase<'r>: Sized {
const SPECIALIZED: bool = false;

fn cast(self) -> ErasedError<'r> { Box::new(()) }
}

impl<'r, T: 'r> DefaultTypeErase<'r> for Resolve<'r, T> {}

/// "Specialized" "implementation" of `Transient` for `T: Transient`. This is
/// what Rust will resolve `Resolve<T>::item` to when `T: Transient`.
impl<'r, T: Transient + Send + Sync + 'r> Resolve<'r, T>
where T::Transience: CanTranscendTo<Co<'r>>
{
pub const SPECIALIZED: bool = true;

pub fn cast(self) -> ErasedError<'r> { Box::new(self.0) }
}
}

#[cfg(test)]
mod test {
// use std::any::TypeId;

use transient::{Transient, TypeId};

use super::resolution::{Resolve, DefaultTypeErase};

struct NotAny;
#[derive(Transient)]
struct YesAny;

#[test]
fn check_can_determine() {
let not_any = Resolve::new(NotAny).cast();
assert_eq!(not_any.type_id(), TypeId::of::<()>());

let yes_any = Resolve::new(YesAny).cast();
assert_ne!(yes_any.type_id(), TypeId::of::<()>());
}

// struct HasSentinel<T>(T);

// #[test]
// fn parent_works() {
// let child = resolve!(YesASentinel, HasSentinel<YesASentinel>);
// assert!(child.type_name.ends_with("YesASentinel"));
// assert_eq!(child.parent.unwrap(), TypeId::of::<HasSentinel<YesASentinel>>());
// assert!(child.specialized);

// let not_a_direct_sentinel = resolve!(HasSentinel<YesASentinel>);
// assert!(not_a_direct_sentinel.type_name.contains("HasSentinel"));
// assert!(not_a_direct_sentinel.type_name.contains("YesASentinel"));
// assert!(not_a_direct_sentinel.parent.is_none());
// assert!(!not_a_direct_sentinel.specialized);
// }
}
Loading

0 comments on commit 56e7fa6

Please sign in to comment.