Skip to content

Commit

Permalink
Ensure examples pass CI
Browse files Browse the repository at this point in the history
  • Loading branch information
the10thWiz committed Jun 30, 2024
1 parent 4cb3a3a commit ac3a7fa
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 26 deletions.
17 changes: 12 additions & 5 deletions core/codegen/src/attribute/catch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,21 @@ use crate::http_codegen::Optional;
use crate::syn_ext::ReturnTypeExt;
use crate::exports::*;

fn arg_ty(arg: &syn::FnArg) -> Result<&syn::Type> {
fn error_arg_ty(arg: &syn::FnArg) -> Result<&syn::Type> {
match arg {
syn::FnArg::Receiver(_) => Err(Diagnostic::spanned(
arg.span(),
Level::Error,
"Catcher cannot have self as a parameter"
"Catcher cannot have self as a parameter",
)),
syn::FnArg::Typed(syn::PatType {ty, ..})=> Ok(ty.as_ref()),
syn::FnArg::Typed(syn::PatType { ty, .. }) => match ty.as_ref() {
syn::Type::Reference(syn::TypeReference { elem, .. }) => Ok(elem.as_ref()),
_ => Err(Diagnostic::spanned(
ty.span(),
Level::Error,
"Error type must be a reference",
)),
},
}
}

Expand Down Expand Up @@ -58,9 +65,9 @@ pub fn _catch(
}).rev();
let (make_error, error_type) = if catch.function.sig.inputs.len() >= 3 {
let arg = catch.function.sig.inputs.first().unwrap();
let ty = arg_ty(arg)?;
let ty = error_arg_ty(arg)?;
(quote_spanned!(arg.span() =>
let #__error = match ::rocket::catcher::downcast(__error_init.as_ref()) {
let #__error: &#ty = match ::rocket::catcher::downcast(__error_init.as_ref()) {
Some(v) => v,
None => return #_Result::Err((#__status, __error_init)),
};
Expand Down
34 changes: 33 additions & 1 deletion core/lib/src/outcome.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@
//! a type of `Option<S>`. If an `Outcome` is a `Forward`, the `Option` will be
//! `None`.
use crate::catcher::default_error_type;
use transient::{CanTranscendTo, Co, Transient};

use crate::catcher::{default_error_type, ErasedError};
use crate::{route, request, response};
use crate::data::{self, Data, FromData};
use crate::http::Status;
Expand Down Expand Up @@ -809,3 +811,33 @@ impl<'r, 'o: 'r> IntoOutcome<route::Outcome<'r>> for response::Result<'o> {
}
}
}

type RoutedOutcome<'r, T> = Outcome<
T,
(Status, ErasedError<'r>),
(Data<'r>, Status, ErasedError<'r>)
>;

impl<'r, T, E: Transient> IntoOutcome<RoutedOutcome<'r, T>> for Option<Result<T, E>>
where E::Transience: CanTranscendTo<Co<'r>>,
E: Send + Sync + 'r,
{
type Error = Status;
type Forward = (Data<'r>, Status);

fn or_error(self, error: Self::Error) -> RoutedOutcome<'r, T> {
match self {
Some(Ok(v)) => Outcome::Success(v),
Some(Err(e)) => Outcome::Error((error, Box::new(e))),
None => Outcome::Error((error, default_error_type())),
}
}

fn or_forward(self, forward: Self::Forward) -> RoutedOutcome<'r, T> {
match self {
Some(Ok(v)) => Outcome::Success(v),
Some(Err(e)) => Outcome::Forward((forward.0, forward.1, Box::new(e))),
None => Outcome::Forward((forward.0, forward.1, default_error_type())),
}
}
}
7 changes: 4 additions & 3 deletions examples/error-handling/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use rocket::{Rocket, Request, Build};
use rocket::response::{content, status};
use rocket::http::{Status, uri::error::PathError};
use rocket::http::Status;

// Custom impl so I can implement Static (or Transient) ---
// We should upstream implementations for most common error types
Expand All @@ -13,6 +13,7 @@ use rocket::catcher::{Static};
use std::num::ParseIntError;

#[derive(Debug)]
#[allow(unused)]
struct IntErr(ParseIntError);
impl Static for IntErr {}

Expand Down Expand Up @@ -45,7 +46,7 @@ fn general_not_found() -> content::RawHtml<&'static str> {
}

#[catch(404)]
fn hello_not_found(s: Status, req: &Request<'_>) -> content::RawHtml<String> {
fn hello_not_found(req: &Request<'_>) -> content::RawHtml<String> {
content::RawHtml(format!("\
<p>Sorry, but '{}' is not a valid path!</p>\
<p>Try visiting /hello/&lt;name&gt;/&lt;age&gt; instead.</p>",
Expand All @@ -57,7 +58,7 @@ fn hello_not_found(s: Status, req: &Request<'_>) -> content::RawHtml<String> {
// be present. I'm thinking about adding a param to the macro to indicate which (and whether)
// param is a downcast error.
#[catch(422)]
fn param_error(e: &IntErr, s: Status, req: &Request<'_>) -> content::RawHtml<String> {
fn param_error(e: &IntErr, _s: Status, req: &Request<'_>) -> content::RawHtml<String> {
content::RawHtml(format!("\
<p>Sorry, but '{}' is not a valid path!</p>\
<p>Try visiting /hello/&lt;name&gt;/&lt;age&gt; instead.</p>\
Expand Down
11 changes: 8 additions & 3 deletions examples/error-handling/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use rocket::local::blocking::Client;
use rocket::http::Status;
use super::{I8, IntErr};

#[test]
fn test_hello() {
Expand All @@ -10,7 +11,7 @@ fn test_hello() {
let response = client.get(uri).dispatch();

assert_eq!(response.status(), Status::Ok);
assert_eq!(response.into_string().unwrap(), super::hello(name, age));
assert_eq!(response.into_string().unwrap(), super::hello(name, I8(age)));
}

#[test]
Expand Down Expand Up @@ -48,10 +49,14 @@ fn test_hello_invalid_age() {

for path in &["Ford/-129", "Trillian/128"] {
let request = client.get(format!("/hello/{}", path));
let expected = super::default_catcher(Status::UnprocessableEntity, request.inner());
let expected = super::param_error(
&IntErr(path.split_once("/").unwrap().1.parse::<i8>().unwrap_err()),
Status::UnprocessableEntity,
request.inner()
);
let response = request.dispatch();
assert_eq!(response.status(), Status::UnprocessableEntity);
assert_eq!(response.into_string().unwrap(), expected.1);
assert_eq!(response.into_string().unwrap(), expected.0);
}

{
Expand Down
40 changes: 26 additions & 14 deletions examples/manual-routing/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#[cfg(test)]
mod tests;

use rocket::{Request, Route, Catcher, route, catcher};
use rocket::{Request, Route, Catcher, route, catcher, outcome::Outcome};
use rocket::data::{Data, ToByteUnit};
use rocket::http::{Status, Method::{Get, Post}};
use rocket::response::{Responder, status::Custom};
use rocket::outcome::{try_outcome, IntoOutcome};
use rocket::tokio::fs::File;

fn forward<'r>(_req: &'r Request, data: Data<'r>) -> route::BoxFuture<'r> {
Expand All @@ -25,12 +24,17 @@ fn name<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
}

fn echo_url<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
let param_outcome = req.param::<&str>(1)
.and_then(Result::ok)
.or_error(Status::BadRequest);

Box::pin(async move {
route::Outcome::from(req, try_outcome!(param_outcome))
let param_outcome = match req.param::<&str>(1) {
Some(Ok(v)) => v,
Some(Err(e)) => return Outcome::Error((
Status::BadRequest,
Box::new(e) as catcher::ErasedError
)),
None => return Outcome::Error((Status::BadRequest, catcher::default_error_type())),
};

route::Outcome::from(req, param_outcome)
})
}

Expand Down Expand Up @@ -62,9 +66,11 @@ fn get_upload<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
route::Outcome::from(req, std::fs::File::open(path).ok()).pin()
}

fn not_found_handler<'r>(_: Status, req: &'r Request) -> catcher::BoxFuture<'r> {
fn not_found_handler<'r>(_: Status, req: &'r Request, _e: catcher::ErasedError<'r>)
-> catcher::BoxFuture<'r>
{
let responder = Custom(Status::NotFound, format!("Couldn't find: {}", req.uri()));
Box::pin(async move { responder.respond_to(req) })
Box::pin(async move { responder.respond_to(req).map_err(|s| (s, _e)) })
}

#[derive(Clone)]
Expand All @@ -82,11 +88,17 @@ impl CustomHandler {
impl route::Handler for CustomHandler {
async fn handle<'r>(&self, req: &'r Request<'_>, data: Data<'r>) -> route::Outcome<'r> {
let self_data = self.data;
let id = req.param::<&str>(0)
.and_then(Result::ok)
.or_forward((data, Status::NotFound));

route::Outcome::from(req, format!("{} - {}", self_data, try_outcome!(id)))
let id = match req.param::<&str>(1) {
Some(Ok(v)) => v,
Some(Err(e)) => return Outcome::Forward((data, Status::BadRequest, Box::new(e))),
None => return Outcome::Forward((
data,
Status::BadRequest,
catcher::default_error_type()
)),
};

route::Outcome::from(req, format!("{} - {}", self_data, id))
}
}

Expand Down

0 comments on commit ac3a7fa

Please sign in to comment.