From 50bc628a2645b71d4b422f6b2b5fe0f4b6085abd Mon Sep 17 00:00:00 2001 From: Jeremiah Senkpiel Date: Thu, 16 Jul 2020 10:51:02 -0700 Subject: [PATCH 1/2] Server: require State to be Clone Alternative to https://github.com/http-rs/tide/pull/642 This approach is more flexible but requires the user ensure that their state implements/derives `Clone`, or is wrapped in an `Arc`. Co-authored-by: Jacob Rothstein --- Cargo.toml | 1 + examples/graphql.rs | 7 +++-- examples/middleware.rs | 4 +-- examples/upload.rs | 8 +++--- src/cookies/middleware.rs | 2 +- src/endpoint.rs | 8 +++--- src/fs/serve_dir.rs | 6 ++--- src/lib.rs | 3 ++- src/listener/concurrent_listener.rs | 4 +-- src/listener/failover_listener.rs | 4 +-- src/listener/parsed_listener.rs | 2 +- src/listener/tcp_listener.rs | 4 +-- src/listener/to_listener.rs | 38 +++++++++++++------------- src/listener/unix_listener.rs | 4 +-- src/log/middleware.rs | 4 +-- src/middleware.rs | 4 +-- src/redirect.rs | 2 +- src/request.rs | 41 ++++++++++++++--------------- src/route.rs | 8 +++--- src/router.rs | 10 ++++--- src/security/cors.rs | 2 +- src/server.rs | 13 ++++----- src/sse/endpoint.rs | 6 ++--- src/sse/upgrade.rs | 2 +- src/utils.rs | 4 +-- tests/route_middleware.rs | 2 +- tests/test_utils.rs | 2 +- 27 files changed, 103 insertions(+), 92 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e1c547c3c..8cfa09bd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ route-recognizer = "0.2.0" logtest = "2.0.0" async-trait = "0.1.36" futures-util = "0.3.5" +pin-project-lite = "0.1.7" [dev-dependencies] async-std = { version = "1.6.0", features = ["unstable", "attributes"] } diff --git a/examples/graphql.rs b/examples/graphql.rs index c065bddb0..a7662f655 100644 --- a/examples/graphql.rs +++ b/examples/graphql.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_std::task; use juniper::{http::graphiql, http::GraphQLRequest, RootNode}; use std::sync::RwLock; @@ -37,8 +39,9 @@ impl NewUser { } } +#[derive(Clone)] pub struct State { - users: RwLock>, + users: Arc>>, } impl juniper::Context for State {} @@ -96,7 +99,7 @@ async fn handle_graphiql(_: Request) -> tide::Result> fn main() -> std::io::Result<()> { task::block_on(async { let mut app = Server::with_state(State { - users: RwLock::new(Vec::new()), + users: Arc::new(RwLock::new(Vec::new())), }); app.at("/").get(Redirect::permanent("/graphiql")); app.at("/graphql").post(handle_graphql); diff --git a/examples/middleware.rs b/examples/middleware.rs index 18544848a..b12fbf263 100644 --- a/examples/middleware.rs +++ b/examples/middleware.rs @@ -11,7 +11,7 @@ struct User { name: String, } -#[derive(Default, Debug)] +#[derive(Clone, Default, Debug)] struct UserDatabase; impl UserDatabase { async fn find_user(&self) -> Option { @@ -62,7 +62,7 @@ impl RequestCounterMiddleware { struct RequestCount(usize); #[tide::utils::async_trait] -impl Middleware for RequestCounterMiddleware { +impl Middleware for RequestCounterMiddleware { async fn handle(&self, mut req: Request, next: Next<'_, State>) -> Result { let count = self.requests_counted.fetch_add(1, Ordering::Relaxed); tide::log::trace!("request counter", { count: count }); diff --git a/examples/upload.rs b/examples/upload.rs index e5fa00b3c..5bc42e862 100644 --- a/examples/upload.rs +++ b/examples/upload.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_std::{fs::OpenOptions, io}; use tempfile::TempDir; use tide::prelude::*; @@ -6,7 +8,7 @@ use tide::{Body, Request, Response, StatusCode}; #[async_std::main] async fn main() -> Result<(), std::io::Error> { tide::log::start(); - let mut app = tide::with_state(tempfile::tempdir()?); + let mut app = tide::with_state(Arc::new(tempfile::tempdir()?)); // To test this example: // $ cargo run --example upload @@ -14,7 +16,7 @@ async fn main() -> Result<(), std::io::Error> { // $ curl localhost:8080/README.md # this reads the file from the same temp directory app.at(":file") - .put(|req: Request| async move { + .put(|req: Request>| async move { let path: String = req.param("file")?; let fs_path = req.state().path().join(path); @@ -33,7 +35,7 @@ async fn main() -> Result<(), std::io::Error> { Ok(json!({ "bytes": bytes_written })) }) - .get(|req: Request| async move { + .get(|req: Request>| async move { let path: String = req.param("file")?; let fs_path = req.state().path().join(path); diff --git a/src/cookies/middleware.rs b/src/cookies/middleware.rs index 227b13d00..cc9473cf7 100644 --- a/src/cookies/middleware.rs +++ b/src/cookies/middleware.rs @@ -35,7 +35,7 @@ impl CookiesMiddleware { } #[async_trait] -impl Middleware for CookiesMiddleware { +impl Middleware for CookiesMiddleware { async fn handle(&self, mut ctx: Request, next: Next<'_, State>) -> crate::Result { let cookie_jar = if let Some(cookie_data) = ctx.ext::() { cookie_data.content.clone() diff --git a/src/endpoint.rs b/src/endpoint.rs index e505b78e3..0fecf0b9b 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -45,7 +45,7 @@ use crate::{Middleware, Request, Response}; /// /// Tide routes will also accept endpoints with `Fn` signatures of this form, but using the `async` keyword has better ergonomics. #[async_trait] -pub trait Endpoint: Send + Sync + 'static { +pub trait Endpoint: Send + Sync + 'static { /// Invoke the endpoint within the given context async fn call(&self, req: Request) -> crate::Result; } @@ -55,7 +55,7 @@ pub(crate) type DynEndpoint = dyn Endpoint; #[async_trait] impl Endpoint for F where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, F: Send + Sync + 'static + Fn(Request) -> Fut, Fut: Future> + Send + 'static, Res: Into + 'static, @@ -93,7 +93,7 @@ impl std::fmt::Debug for MiddlewareEndpoint { impl MiddlewareEndpoint where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, E: Endpoint, { pub fn wrap_with_middleware(ep: E, middleware: &[Arc>]) -> Self { @@ -107,7 +107,7 @@ where #[async_trait] impl Endpoint for MiddlewareEndpoint where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, E: Endpoint, { async fn call(&self, req: Request) -> crate::Result { diff --git a/src/fs/serve_dir.rs b/src/fs/serve_dir.rs index a79e008b7..1dff4f46c 100644 --- a/src/fs/serve_dir.rs +++ b/src/fs/serve_dir.rs @@ -21,7 +21,7 @@ impl ServeDir { #[async_trait::async_trait] impl Endpoint for ServeDir where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, { async fn call(&self, req: Request) -> Result { let path = req.url().path(); @@ -60,8 +60,6 @@ where mod test { use super::*; - use async_std::sync::Arc; - use std::fs::{self, File}; use std::io::Write; @@ -83,7 +81,7 @@ mod test { let request = crate::http::Request::get( crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(), ); - crate::Request::new(Arc::new(()), request, vec![]) + crate::Request::new((), request, vec![]) } #[async_std::test] diff --git a/src/lib.rs b/src/lib.rs index 8f560208e..07cd85aab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -259,6 +259,7 @@ pub fn new() -> server::Server<()> { /// use tide::Request; /// /// /// The shared application state. +/// #[derive(Clone)] /// struct State { /// name: String, /// } @@ -279,7 +280,7 @@ pub fn new() -> server::Server<()> { /// ``` pub fn with_state(state: State) -> server::Server where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, { Server::with_state(state) } diff --git a/src/listener/concurrent_listener.rs b/src/listener/concurrent_listener.rs index 7e5e586a3..1cff9b0d7 100644 --- a/src/listener/concurrent_listener.rs +++ b/src/listener/concurrent_listener.rs @@ -35,7 +35,7 @@ use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt}; #[derive(Default)] pub struct ConcurrentListener(Vec>>); -impl ConcurrentListener { +impl ConcurrentListener { /// creates a new ConcurrentListener pub fn new() -> Self { Self(vec![]) @@ -78,7 +78,7 @@ impl ConcurrentListener { } #[async_trait::async_trait] -impl Listener for ConcurrentListener { +impl Listener for ConcurrentListener { async fn listen(&mut self, app: Server) -> io::Result<()> { let mut futures_unordered = FuturesUnordered::new(); diff --git a/src/listener/failover_listener.rs b/src/listener/failover_listener.rs index ac37a7c10..4ab1bd242 100644 --- a/src/listener/failover_listener.rs +++ b/src/listener/failover_listener.rs @@ -35,7 +35,7 @@ use async_std::io; #[derive(Default)] pub struct FailoverListener(Vec>>); -impl FailoverListener { +impl FailoverListener { /// creates a new FailoverListener pub fn new() -> Self { Self(vec![]) @@ -80,7 +80,7 @@ impl FailoverListener { } #[async_trait::async_trait] -impl Listener for FailoverListener { +impl Listener for FailoverListener { async fn listen(&mut self, app: Server) -> io::Result<()> { for listener in self.0.iter_mut() { let app = app.clone(); diff --git a/src/listener/parsed_listener.rs b/src/listener/parsed_listener.rs index c98dc98ef..4b0e186a9 100644 --- a/src/listener/parsed_listener.rs +++ b/src/listener/parsed_listener.rs @@ -31,7 +31,7 @@ impl Display for ParsedListener { } #[async_trait::async_trait] -impl Listener for ParsedListener { +impl Listener for ParsedListener { async fn listen(&mut self, app: Server) -> io::Result<()> { match self { #[cfg(unix)] diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index f6364fa3d..db68530ee 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -51,7 +51,7 @@ impl TcpListener { } } -fn handle_tcp(app: Server, stream: TcpStream) { +fn handle_tcp(app: Server, stream: TcpStream) { task::spawn(async move { let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); @@ -69,7 +69,7 @@ fn handle_tcp(app: Server, stream: TcpStrea } #[async_trait::async_trait] -impl Listener for TcpListener { +impl Listener for TcpListener { async fn listen(&mut self, app: Server) -> io::Result<()> { self.connect().await?; let listener = self.listener()?; diff --git a/src/listener/to_listener.rs b/src/listener/to_listener.rs index 461960864..e5efe244d 100644 --- a/src/listener/to_listener.rs +++ b/src/listener/to_listener.rs @@ -52,7 +52,7 @@ use std::net::ToSocketAddrs; /// # Other implementations /// See below for additional provided implementations of ToListener. -pub trait ToListener { +pub trait ToListener { type Listener: Listener; /// Transform self into a /// [`Listener`](crate::listener::Listener). Unless self is @@ -63,7 +63,7 @@ pub trait ToListener { fn to_listener(self) -> io::Result; } -impl ToListener for Url { +impl ToListener for Url { type Listener = ParsedListener; fn to_listener(self) -> io::Result { @@ -106,14 +106,14 @@ impl ToListener for Url { } } -impl ToListener for String { +impl ToListener for String { type Listener = ParsedListener; fn to_listener(self) -> io::Result { ToListener::::to_listener(self.as_str()) } } -impl ToListener for &str { +impl ToListener for &str { type Listener = ParsedListener; fn to_listener(self) -> io::Result { @@ -133,7 +133,7 @@ impl ToListener for &str { } #[cfg(unix)] -impl ToListener for async_std::path::PathBuf { +impl ToListener for async_std::path::PathBuf { type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_path(self)) @@ -141,28 +141,28 @@ impl ToListener for async_std::path::PathBu } #[cfg(unix)] -impl ToListener for std::path::PathBuf { +impl ToListener for std::path::PathBuf { type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_path(self)) } } -impl ToListener for async_std::net::TcpListener { +impl ToListener for async_std::net::TcpListener { type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_listener(self)) } } -impl ToListener for std::net::TcpListener { +impl ToListener for std::net::TcpListener { type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_listener(self)) } } -impl ToListener for (&str, u16) { +impl ToListener for (&str, u16) { type Listener = TcpListener; fn to_listener(self) -> io::Result { @@ -171,7 +171,9 @@ impl ToListener for (&str, u16) { } #[cfg(unix)] -impl ToListener for async_std::os::unix::net::UnixListener { +impl ToListener + for async_std::os::unix::net::UnixListener +{ type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_listener(self)) @@ -179,14 +181,14 @@ impl ToListener for async_std::os::unix::ne } #[cfg(unix)] -impl ToListener for std::os::unix::net::UnixListener { +impl ToListener for std::os::unix::net::UnixListener { type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_listener(self)) } } -impl ToListener for TcpListener { +impl ToListener for TcpListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) @@ -194,42 +196,42 @@ impl ToListener for TcpListener { } #[cfg(unix)] -impl ToListener for UnixListener { +impl ToListener for UnixListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for ConcurrentListener { +impl ToListener for ConcurrentListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for ParsedListener { +impl ToListener for ParsedListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for FailoverListener { +impl ToListener for FailoverListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for std::net::SocketAddr { +impl ToListener for std::net::SocketAddr { type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_addrs(vec![self])) } } -impl, State: Send + Sync + 'static> ToListener for Vec { +impl, State: Clone + Send + Sync + 'static> ToListener for Vec { type Listener = ConcurrentListener; fn to_listener(self) -> io::Result { let mut concurrent_listener = ConcurrentListener::new(); diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index cea1559c4..72aff852d 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -64,7 +64,7 @@ fn unix_socket_addr_to_string(result: io::Result) -> Option }) } -fn handle_unix(app: Server, stream: UnixStream) { +fn handle_unix(app: Server, stream: UnixStream) { task::spawn(async move { let local_addr = unix_socket_addr_to_string(stream.local_addr()); let peer_addr = unix_socket_addr_to_string(stream.peer_addr()); @@ -82,7 +82,7 @@ fn handle_unix(app: Server, stream: UnixStr } #[async_trait::async_trait] -impl Listener for UnixListener { +impl Listener for UnixListener { async fn listen(&mut self, app: Server) -> io::Result<()> { self.connect().await?; crate::log::info!("Server listening on {}", self); diff --git a/src/log/middleware.rs b/src/log/middleware.rs index 4e3581938..7ce87e723 100644 --- a/src/log/middleware.rs +++ b/src/log/middleware.rs @@ -24,7 +24,7 @@ impl LogMiddleware { } /// Log a request and a response. - async fn log<'a, State: Send + Sync + 'static>( + async fn log<'a, State: Clone + Send + Sync + 'static>( &'a self, ctx: Request, next: Next<'a, State>, @@ -75,7 +75,7 @@ impl LogMiddleware { } #[async_trait::async_trait] -impl Middleware for LogMiddleware { +impl Middleware for LogMiddleware { async fn handle(&self, req: Request, next: Next<'_, State>) -> crate::Result { self.log(req, next).await } diff --git a/src/middleware.rs b/src/middleware.rs index 29134ca01..7e1ca9a90 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -23,7 +23,7 @@ pub trait Middleware: Send + Sync + 'static { #[async_trait] impl Middleware for F where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, F: Send + Sync + 'static @@ -44,7 +44,7 @@ pub struct Next<'a, State> { pub(crate) next_middleware: &'a [Arc>], } -impl Next<'_, State> { +impl Next<'_, State> { /// Asynchronously execute the remaining middleware chain. pub async fn run(mut self, req: Request) -> Response { if let Some((current, next)) = self.next_middleware.split_first() { diff --git a/src/redirect.rs b/src/redirect.rs index a9ae75e2d..0230279ec 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -88,7 +88,7 @@ impl> Redirect { #[async_trait::async_trait] impl Endpoint for Redirect where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, T: AsRef + Send + Sync + 'static, { async fn call(&self, _req: Request) -> crate::Result { diff --git a/src/request.rs b/src/request.rs index 332121c82..bb2cd83b2 100644 --- a/src/request.rs +++ b/src/request.rs @@ -4,7 +4,7 @@ use route_recognizer::Params; use std::ops::Index; use std::pin::Pin; -use std::{fmt, str::FromStr, sync::Arc}; +use std::{fmt, str::FromStr}; use crate::cookies::CookieData; use crate::http::cookies::Cookie; @@ -12,18 +12,21 @@ use crate::http::headers::{self, HeaderName, HeaderValues, ToHeaderValues}; use crate::http::{self, Body, Method, Mime, StatusCode, Url, Version}; use crate::Response; -/// An HTTP request. -/// -/// The `Request` gives endpoints access to basic information about the incoming -/// request, route parameters, and various ways of accessing the request's body. -/// -/// Requests also provide *extensions*, a type map primarily used for low-level -/// communication between middleware and endpoints. -#[derive(Debug)] -pub struct Request { - pub(crate) state: Arc, - pub(crate) req: http::Request, - pub(crate) route_params: Vec, +pin_project_lite::pin_project! { + /// An HTTP request. + /// + /// The `Request` gives endpoints access to basic information about the incoming + /// request, route parameters, and various ways of accessing the request's body. + /// + /// Requests also provide *extensions*, a type map primarily used for low-level + /// communication between middleware and endpoints. + #[derive(Debug)] + pub struct Request { + pub(crate) state: State, + #[pin] + pub(crate) req: http::Request, + pub(crate) route_params: Vec, + } } #[derive(Debug)] @@ -45,11 +48,7 @@ impl std::error::Error for ParamError {} impl Request { /// Create a new `Request`. - pub(crate) fn new( - state: Arc, - req: http_types::Request, - route_params: Vec, - ) -> Self { + pub(crate) fn new(state: State, req: http_types::Request, route_params: Vec) -> Self { Self { state, req, @@ -550,11 +549,11 @@ impl AsMut for Request { impl Read for Request { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - Pin::new(&mut self.req).poll_read(cx, buf) + self.project().req.poll_read(cx, buf) } } @@ -566,7 +565,7 @@ impl Into for Request { // NOTE: From cannot be implemented for this conversion because `State` needs to // be constrained by a type. -impl Into for Request { +impl Into for Request { fn into(mut self) -> Response { let mut res = Response::new(StatusCode::Ok); res.set_body(self.take_body()); diff --git a/src/route.rs b/src/route.rs index 32cbe486d..258f230f0 100644 --- a/src/route.rs +++ b/src/route.rs @@ -28,7 +28,7 @@ pub struct Route<'a, State> { prefix: bool, } -impl<'a, State: Send + Sync + 'static> Route<'a, State> { +impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { pub(crate) fn new(router: &'a mut Router, path: String) -> Route<'a, State> { Route { router, @@ -101,8 +101,8 @@ impl<'a, State: Send + Sync + 'static> Route<'a, State> { /// [`Server`]: struct.Server.html pub fn nest(&mut self, service: crate::Server) -> &mut Self where - State: Send + Sync + 'static, - InnerState: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, + InnerState: Clone + Send + Sync + 'static, { self.prefix = true; self.all(service); @@ -276,7 +276,7 @@ impl Clone for StripPrefixEndpoint { #[async_trait::async_trait] impl Endpoint for StripPrefixEndpoint where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, E: Endpoint, { async fn call(&self, req: crate::Request) -> crate::Result { diff --git a/src/router.rs b/src/router.rs index bab4975be..b3673ee7d 100644 --- a/src/router.rs +++ b/src/router.rs @@ -21,7 +21,7 @@ pub struct Selection<'a, State> { pub(crate) params: Params, } -impl Router { +impl Router { pub fn new() -> Self { Router { method_map: HashMap::default(), @@ -81,10 +81,14 @@ impl Router { } } -async fn not_found_endpoint(_req: Request) -> crate::Result { +async fn not_found_endpoint( + _req: Request, +) -> crate::Result { Ok(Response::new(StatusCode::NotFound)) } -async fn method_not_allowed(_req: Request) -> crate::Result { +async fn method_not_allowed( + _req: Request, +) -> crate::Result { Ok(Response::new(StatusCode::MethodNotAllowed)) } diff --git a/src/security/cors.rs b/src/security/cors.rs index 65da7216e..12d0b4b02 100644 --- a/src/security/cors.rs +++ b/src/security/cors.rs @@ -133,7 +133,7 @@ impl CorsMiddleware { } #[async_trait::async_trait] -impl Middleware for CorsMiddleware { +impl Middleware for CorsMiddleware { async fn handle(&self, req: Request, next: Next<'_, State>) -> Result { // TODO: how should multiple origin values be handled? let origins = req.header(&headers::ORIGIN).cloned(); diff --git a/src/server.rs b/src/server.rs index 265e489fd..ae8cd1fdc 100644 --- a/src/server.rs +++ b/src/server.rs @@ -27,7 +27,7 @@ use crate::{Endpoint, Request, Route}; #[allow(missing_debug_implementations)] pub struct Server { router: Arc>, - state: Arc, + state: State, middleware: Arc>>>, } @@ -58,7 +58,7 @@ impl Default for Server<()> { } } -impl Server { +impl Server { /// Create a new Tide server with shared application scoped state. /// /// Application scoped state is useful for storing items @@ -72,6 +72,7 @@ impl Server { /// use tide::Request; /// /// /// The shared application state. + /// #[derive(Clone)] /// struct State { /// name: String, /// } @@ -94,7 +95,7 @@ impl Server { let mut server = Self { router: Arc::new(Router::new()), middleware: Arc::new(vec![]), - state: Arc::new(state), + state, }; server.middleware(cookies::CookiesMiddleware::new()); server.middleware(log::LogMiddleware::new()); @@ -228,7 +229,7 @@ impl Server { } } -impl Clone for Server { +impl Clone for Server { fn clone(&self) -> Self { Self { router: self.router.clone(), @@ -239,8 +240,8 @@ impl Clone for Server { } #[async_trait::async_trait] -impl Endpoint - for Server +impl + Endpoint for Server { async fn call(&self, req: Request) -> crate::Result { let Request { diff --git a/src/sse/endpoint.rs b/src/sse/endpoint.rs index 01aae2fdd..bf10462d7 100644 --- a/src/sse/endpoint.rs +++ b/src/sse/endpoint.rs @@ -13,7 +13,7 @@ use std::sync::Arc; /// Create an endpoint that can handle SSE connections. pub fn endpoint(handler: F) -> SseEndpoint where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { @@ -28,7 +28,7 @@ where #[derive(Debug)] pub struct SseEndpoint where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { @@ -40,7 +40,7 @@ where #[async_trait::async_trait] impl Endpoint for SseEndpoint where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { diff --git a/src/sse/upgrade.rs b/src/sse/upgrade.rs index 9eff4b864..020e15b89 100644 --- a/src/sse/upgrade.rs +++ b/src/sse/upgrade.rs @@ -11,7 +11,7 @@ use async_std::task; /// Upgrade an existing HTTP connection to an SSE connection. pub fn upgrade(req: Request, handler: F) -> Response where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { diff --git a/src/utils.rs b/src/utils.rs index 182c51a36..79fd3b160 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -27,7 +27,7 @@ pub struct Before(pub F); #[async_trait] impl Middleware for Before where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, F: Fn(Request) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { @@ -61,7 +61,7 @@ pub struct After(pub F); #[async_trait] impl Middleware for After where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, F: Fn(Response) -> Fut + Send + Sync + 'static, Fut: Future + Send + Sync + 'static, { diff --git a/tests/route_middleware.rs b/tests/route_middleware.rs index dba61322f..7e78ee240 100644 --- a/tests/route_middleware.rs +++ b/tests/route_middleware.rs @@ -14,7 +14,7 @@ impl TestMiddleware { } #[async_trait::async_trait] -impl Middleware for TestMiddleware { +impl Middleware for TestMiddleware { async fn handle( &self, req: tide::Request, diff --git a/tests/test_utils.rs b/tests/test_utils.rs index d8c5d900d..cf2961e3a 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -21,7 +21,7 @@ pub trait ServerTestingExt { #[async_trait::async_trait] impl ServerTestingExt for Server where - State: Send + Sync + 'static, + State: Clone + Send + Sync + 'static, { async fn request(&self, method: Method, path: &str) -> http::Response { let url = if path.starts_with("http:") { From 30c244368fe40e3b77042906e3cd8bb1171e0794 Mon Sep 17 00:00:00 2001 From: Jeremiah Senkpiel Date: Thu, 16 Jul 2020 23:33:28 -0700 Subject: [PATCH 2/2] Examples: improve state handling in upload example Addresses https://github.com/http-rs/tide/pull/644#pullrequestreview-450220840 --- examples/upload.rs | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/examples/upload.rs b/examples/upload.rs index 5bc42e862..5ac0e1d06 100644 --- a/examples/upload.rs +++ b/examples/upload.rs @@ -1,3 +1,5 @@ +use std::io::Error as IoError; +use std::path::Path; use std::sync::Arc; use async_std::{fs::OpenOptions, io}; @@ -5,10 +7,27 @@ use tempfile::TempDir; use tide::prelude::*; use tide::{Body, Request, Response, StatusCode}; +#[derive(Clone)] +struct TempDirState { + tempdir: Arc, +} + +impl TempDirState { + fn try_new() -> Result { + Ok(Self { + tempdir: Arc::new(tempfile::tempdir()?), + }) + } + + fn path(&self) -> &Path { + self.tempdir.path() + } +} + #[async_std::main] -async fn main() -> Result<(), std::io::Error> { +async fn main() -> Result<(), IoError> { tide::log::start(); - let mut app = tide::with_state(Arc::new(tempfile::tempdir()?)); + let mut app = tide::with_state(TempDirState::try_new()?); // To test this example: // $ cargo run --example upload @@ -16,7 +35,7 @@ async fn main() -> Result<(), std::io::Error> { // $ curl localhost:8080/README.md # this reads the file from the same temp directory app.at(":file") - .put(|req: Request>| async move { + .put(|req: Request| async move { let path: String = req.param("file")?; let fs_path = req.state().path().join(path); @@ -35,7 +54,7 @@ async fn main() -> Result<(), std::io::Error> { Ok(json!({ "bytes": bytes_written })) }) - .get(|req: Request>| async move { + .get(|req: Request| async move { let path: String = req.param("file")?; let fs_path = req.state().path().join(path);