diff --git a/examples/server_fns_axum/src/app.rs b/examples/server_fns_axum/src/app.rs index 75f5c8fe79..8423a587b9 100644 --- a/examples/server_fns_axum/src/app.rs +++ b/examples/server_fns_axum/src/app.rs @@ -59,6 +59,7 @@ pub fn HomePage() -> impl IntoView { view! {

"Some Simple Server Functions"

+

"Custom Error Types"

@@ -75,6 +76,44 @@ pub fn HomePage() -> impl IntoView { } } +/// Server functions can be made generic, which will register multiple endpoints. +/// +/// If you use generics, you need to explicitly register the server function endpoint for each type +/// with [`server_fn::axum::register_explicit`] or [`server_fn::actix::register_explicit`] +#[component] +pub fn Generic() -> impl IntoView { + use std::fmt::Display; + + #[server] + pub async fn test_fn(input: S) -> Result + where + S: Display, + { + // insert a simulated wait + tokio::time::sleep(std::time::Duration::from_millis(250)).await; + Ok(input.to_string()) + } + + view! { +

Generic Server Functions

+

"Server functions can be made generic, which will register multiple endpoints."

+

+ "If you use generics, you need to explicitly register the server function endpoint for each type." +

+

"Open your browser devtools to see which endpoints the function below calls."

+ + } +} + /// A server function is really just an API call to your server. But it provides a plain async /// function as a wrapper around that. This means you can call it like any other async code, just /// by spawning a task with `spawn_local`. @@ -382,7 +421,8 @@ pub fn FileUpload() -> impl IntoView {

{move || { - if upload_action.input_local().read().is_none() && upload_action.value().read().is_none() + if upload_action.input_local().read().is_none() + && upload_action.value().read().is_none() { "Upload a file.".to_string() } else if upload_action.pending().get() { @@ -929,13 +969,11 @@ pub fn PostcardExample() -> impl IntoView {

Using postcard encoding

"This example demonstrates using Postcard for efficient binary serialization."

+ set_input + .update(|data| { + data.age += 1; + }); + }>"Increment Age" // Display the current input data

"Input: " {move || format!("{:?}", input.get())}

diff --git a/leptos_macro/tests/server.rs b/leptos_macro/tests/server.rs index 2761da8cf2..92193c4776 100644 --- a/leptos_macro/tests/server.rs +++ b/leptos_macro/tests/server.rs @@ -14,7 +14,7 @@ pub mod tests { Ok(()) } assert_eq!( - ::PATH + ::url() .trim_end_matches(char::is_numeric), "/api/my_server_action" ); @@ -30,7 +30,7 @@ pub mod tests { pub async fn my_server_action() -> Result<(), ServerFnError> { Ok(()) } - assert_eq!(::PATH, "/foo/bar/my_path"); + assert_eq!(::url(), "/foo/bar/my_path"); assert_eq!( TypeId::of::<::InputEncoding>(), TypeId::of::() @@ -43,7 +43,7 @@ pub mod tests { pub async fn my_server_action() -> Result<(), ServerFnError> { Ok(()) } - assert_eq!(::PATH, "/foo/bar/my_path"); + assert_eq!(::url(), "/foo/bar/my_path"); assert_eq!( TypeId::of::<::InputEncoding>(), TypeId::of::() @@ -56,7 +56,7 @@ pub mod tests { pub async fn my_server_action() -> Result<(), ServerFnError> { Ok(()) } - assert_eq!(::PATH, "/api/my_path"); + assert_eq!(::url(), "/api/my_path"); assert_eq!( TypeId::of::<::InputEncoding>(), TypeId::of::() @@ -70,7 +70,7 @@ pub mod tests { Ok(()) } assert_eq!( - ::PATH.trim_end_matches(char::is_numeric), + ::url().trim_end_matches(char::is_numeric), "/api/my_server_action" ); assert_eq!( @@ -86,7 +86,7 @@ pub mod tests { Ok(()) } assert_eq!( - ::PATH + ::url() .trim_end_matches(char::is_numeric), "/foo/bar/my_server_action" ); @@ -103,7 +103,7 @@ pub mod tests { Ok(()) } assert_eq!( - ::PATH + ::url() .trim_end_matches(char::is_numeric), "/api/my_server_action" ); @@ -120,7 +120,7 @@ pub mod tests { Ok(()) } assert_eq!( - ::PATH, + ::url(), "/api/path/to/my/endpoint" ); assert_eq!( diff --git a/leptos_server/src/action.rs b/leptos_server/src/action.rs index 177fb9cff1..92b45bdf90 100644 --- a/leptos_server/src/action.rs +++ b/leptos_server/src/action.rs @@ -57,7 +57,7 @@ where #[track_caller] pub fn new() -> Self { let err = use_context::().and_then(|error| { - (error.path() == S::PATH) + (error.path() == S::url()) .then(|| ServerFnError::::de(error.err())) .map(Err) }); @@ -145,7 +145,7 @@ where /// Creates a new [`Action`] that will call the server function `S` when dispatched. pub fn new() -> Self { let err = use_context::().and_then(|error| { - (error.path() == S::PATH) + (error.path() == S::url()) .then(|| ServerFnError::::de(error.err())) .map(Err) }); diff --git a/server_fn/src/lib.rs b/server_fn/src/lib.rs index 5b656d33a1..8361fb9edb 100644 --- a/server_fn/src/lib.rs +++ b/server_fn/src/lib.rs @@ -67,9 +67,11 @@ //! ad hoc HTTP API endpoint, not a magic formula. Any server function can be accessed by any HTTP //! client. You should take care to sanitize any data being returned from the function to ensure it //! does not leak data that should exist only on the server. -//! - **Server functions can’t be generic.** Because each server function creates a separate API endpoint, -//! it is difficult to monomorphize. As a result, server functions cannot be generic (for now?) If you need to use -//! a generic function, you can define a generic inner function called by multiple concrete server functions. +//! - **Generic server fns must be explicitly registered with the type.** Each server function creates +//! a separate API endpoint, which means that the URL can change depending on the generic type. As a +//! result, server functions that are generic must be explicitly registered with the +//! [`axum::register_explicit`] or [`actix::register_explicit`] function call with your generic type +//! passed into it as an argument. //! - **Arguments and return types must be serializable.** We support a variety of different encodings, //! but one way or another arguments need to be serialized to be sent to the server and deserialized //! on the server, and the return type must be serialized on the server and deserialized on the client. @@ -191,9 +193,6 @@ where Self::Error, >, { - /// A unique path for the server function’s API endpoint, relative to the host, including its prefix. - const PATH: &'static str; - /// The type of the HTTP client that will send the request from the client side. /// /// For example, this might be `gloo-net` in the browser, or `reqwest` for a desktop app. @@ -226,10 +225,8 @@ where /// custom error type, this can be `NoCustomError` by default.) type Error: FromStr + Display; - /// Returns [`Self::PATH`]. - fn url() -> &'static str { - Self::PATH - } + /// A unique path for the server function’s API endpoint, relative to the host, including its prefix. + fn url() -> &'static str; /// Middleware that should be applied to this server function. fn middlewares( @@ -265,7 +262,7 @@ where .map(|res| (res, None)) .unwrap_or_else(|e| { ( - Self::ServerResponse::error_response(Self::PATH, &e), + Self::ServerResponse::error_response(Self::url(), &e), Some(e), ) }); @@ -275,7 +272,7 @@ where if accepts_html { // if it had an error, encode that error in the URL if let Some(err) = err { - if let Ok(url) = ServerFnUrlError::new(Self::PATH, err) + if let Ok(url) = ServerFnUrlError::new(Self::url(), err) .to_url(referer.as_deref().unwrap_or("/")) { referer = Some(url.to_string()); @@ -303,7 +300,7 @@ where async move { // create and send request on client let req = - self.into_req(Self::PATH, Self::OutputEncoding::CONTENT_TYPE)?; + self.into_req(Self::url(), Self::OutputEncoding::CONTENT_TYPE)?; Self::run_on_client_with_req(req, redirect::REDIRECT_HOOK.get()) .await } @@ -489,9 +486,9 @@ pub mod axum { > + 'static, { REGISTERED_SERVER_FUNCTIONS.insert( - (T::PATH.into(), T::InputEncoding::METHOD), + (T::url().into(), T::InputEncoding::METHOD), ServerFnTraitObj::new( - T::PATH, + T::url(), T::InputEncoding::METHOD, |req| Box::pin(T::run_on_server(req)), T::middlewares, @@ -577,9 +574,9 @@ pub mod actix { > + 'static, { REGISTERED_SERVER_FUNCTIONS.insert( - (T::PATH.into(), T::InputEncoding::METHOD), + (T::url().into(), T::InputEncoding::METHOD), ServerFnTraitObj::new( - T::PATH, + T::url(), T::InputEncoding::METHOD, |req| Box::pin(T::run_on_server(req)), T::middlewares, diff --git a/server_fn_macro/src/lib.rs b/server_fn_macro/src/lib.rs index 1e80222a8e..05423d564a 100644 --- a/server_fn_macro/src/lib.rs +++ b/server_fn_macro/src/lib.rs @@ -242,6 +242,99 @@ pub fn server_macro_impl( #server_fn_path::codec::Json } }); + + let (impl_generics, ty_generics, where_clause) = + body.generics.split_for_impl(); + let turbofish_ty_generics = ty_generics.as_turbofish(); + + // For the struct declaration, add a where clause where all the fields in the struct have a : Send + 'static bound + let struct_decl_where_clause = + where_clause.cloned().map(|mut where_clause| { + where_clause.predicates = where_clause + .predicates + .into_iter() + .map(|predicate| { + if let WherePredicate::Type(mut t) = predicate { + // Check if the type is used in the struct + let is_type_used = + body.inputs.iter().any(|f| match f { + FnArg::Receiver(_) => false, + FnArg::Typed(typed) => { + *typed.ty == t.bounded_ty + } + }); + + if is_type_used { + // If the type is used in the struct, add the bounds + t.bounds.push(TypeParamBound::Trait(TraitBound { + paren_token: None, + modifier: TraitBoundModifier::None, + lifetimes: None, + path: syn::parse_quote!(Send), + })); + t.bounds.push(TypeParamBound::Lifetime( + syn::parse_quote!('static), + )); + } + WherePredicate::Type(t) + } else { + predicate + } + }) + .collect(); + where_clause + }); + + // Add a `: Serialize + for<'leptos_lifetime_param> Deserialize<'leptos_lifetime_param> + Send + 'static` bound to all types that are used in the struct + let where_clause = where_clause.cloned().map(|mut where_clause| { + where_clause.predicates = where_clause + .predicates + .into_iter() + .map(|predicate| { + if let WherePredicate::Type(mut t) = predicate { + // Check if the type is used in the struct + let is_type_used = body.inputs.iter().any(|f| match f { + FnArg::Receiver(_) => false, + FnArg::Typed(typed) => *typed.ty == t.bounded_ty, + }); + + if is_type_used { + // If the type is used in the struct, add the bounds + t.bounds.push(TypeParamBound::Trait(TraitBound { + paren_token: None, + modifier: TraitBoundModifier::None, + lifetimes: None, + path: syn::parse_quote!( + #server_fn_path::serde::Serialize + ), + })); + t.bounds.push(TypeParamBound::Trait(TraitBound { + paren_token: None, + modifier: TraitBoundModifier::None, + lifetimes: Some(syn::parse_quote!(for<'leptos_param_lifetime>)), + path: syn::parse_quote!( + #server_fn_path::serde::Deserialize::<'leptos_param_lifetime> + ), + })); + t.bounds.push(TypeParamBound::Trait(TraitBound { + paren_token: None, + modifier: TraitBoundModifier::None, + lifetimes: None, + path: syn::parse_quote!(Send), + })); + t.bounds.push(TypeParamBound::Lifetime( + syn::parse_quote!('static), + )); + } + WherePredicate::Type(t) + } else { + predicate + } + }) + .collect(); + where_clause + }); + // default to PascalCase version of function name if no struct name given let struct_name = struct_name.unwrap_or_else(|| { let upper_camel_case_name = Converter::new() @@ -253,15 +346,15 @@ pub fn server_macro_impl( // struct name, wrapped in any custom-encoding newtype wrapper let wrapped_struct_name = if let Some(wrapper) = custom_wrapper.as_ref() { - quote! { #wrapper<#struct_name> } + quote! { #wrapper::<#struct_name #ty_generics> } } else { - quote! { #struct_name } + quote! { #struct_name #ty_generics } }; let wrapped_struct_name_turbofish = if let Some(wrapper) = custom_wrapper.as_ref() { - quote! { #wrapper::<#struct_name> } + quote! { #wrapper::<#struct_name #turbofish_ty_generics> } } else { - quote! { #struct_name } + quote! { #struct_name #turbofish_ty_generics } }; // build struct for type @@ -296,21 +389,22 @@ pub fn server_macro_impl( let impl_from = impl_from.map(|v| v.value).unwrap_or(true); let from_impl = (body.inputs.len() == 1 && first_field.is_some() + && body.generics.params.is_empty() && impl_from) .then(|| { let field = first_field.unwrap(); let (name, ty) = field; quote! { - impl From<#struct_name> for #ty { - fn from(value: #struct_name) -> Self { - let #struct_name { #name } = value; + impl #impl_generics From<#struct_name #ty_generics> for #ty #where_clause { + fn from(value: #struct_name #ty_generics) -> Self { + let #struct_name #turbofish_ty_generics { #name } = value; #name } } - impl From<#ty> for #struct_name { + impl #impl_generics From<#ty> for #struct_name #ty_generics #where_clause { fn from(#name: #ty) -> Self { - #struct_name { #name } + #struct_name #turbofish_ty_generics { #name } } } } @@ -363,34 +457,15 @@ pub fn server_macro_impl( .map(|(doc, span)| quote_spanned!(*span=> #[doc = #doc])) .collect::(); - // auto-registration with inventory - let inventory = if cfg!(feature = "ssr") { - quote! { - #server_fn_path::inventory::submit! {{ - use #server_fn_path::{ServerFn, codec::Encoding}; - #server_fn_path::ServerFnTraitObj::new( - #wrapped_struct_name_turbofish::PATH, - <#wrapped_struct_name as ServerFn>::InputEncoding::METHOD, - |req| { - Box::pin(#wrapped_struct_name_turbofish::run_on_server(req)) - }, - #wrapped_struct_name_turbofish::middlewares - ) - }} - } - } else { - quote! {} - }; - // run_body in the trait implementation - let run_body = if cfg!(feature = "ssr") { + let (run_body_lint_supressor, run_body) = if cfg!(feature = "ssr") { let destructure = if let Some(wrapper) = custom_wrapper.as_ref() { quote! { - let #wrapper(#struct_name { #(#field_names),* }) = self; + let #wrapper(#struct_name #turbofish_ty_generics { #(#field_names),* }) = self; } } else { quote! { - let #struct_name { #(#field_names),* } = self; + let #struct_name #turbofish_ty_generics { #(#field_names),* } = self; } }; @@ -406,32 +481,35 @@ pub fn server_macro_impl( #destructure #dummy_name(#(#field_names),*).await }; - let body = if cfg!(feature = "actix") { + ( quote! { - #server_fn_path::actix::SendWrapper::new(async move { + // we need this for Actix, for the SendWrapper to count as impl Future + // but non-Actix will have a clippy warning otherwise + #[allow(clippy::manual_async_fn)] + }, + if cfg!(feature = "actix") { + quote! { + #server_fn_path::actix::SendWrapper::new(async move { + #body + }) + } + } else { + quote! { async move { #body - }) - } - } else { - quote! { async move { - #body - }} - }; - quote! { - // we need this for Actix, for the SendWrapper to count as impl Future - // but non-Actix will have a clippy warning otherwise - #[allow(clippy::manual_async_fn)] - fn run_body(self) -> impl std::future::Future + Send { - #body - } - } + }} + }, + ) } else { - quote! { - #[allow(unused_variables)] - async fn run_body(self) -> #return_ty { - unreachable!() - } - } + ( + quote! { + #[allow(unused_variables, clippy::manual_async_fn)] + }, + quote! { + async move { + unreachable!() + } + }, + ) }; // the actual function definition @@ -439,7 +517,7 @@ pub fn server_macro_impl( quote! { #docs #(#attrs)* - #vis async fn #fn_name(#(#fn_args),*) #output_arrow #return_ty { + #vis async fn #fn_name #ty_generics (#(#fn_args),*) #output_arrow #return_ty #where_clause { #dummy_name(#(#field_names),*).await } } @@ -447,18 +525,18 @@ pub fn server_macro_impl( let restructure = if let Some(custom_wrapper) = custom_wrapper.as_ref() { quote! { - let data = #custom_wrapper(#struct_name { #(#field_names),* }); + let data = #custom_wrapper(#struct_name #turbofish_ty_generics { #(#field_names),* }); } } else { quote! { - let data = #struct_name { #(#field_names),* }; + let data = #struct_name #turbofish_ty_generics { #(#field_names),* }; } }; quote! { #docs #(#attrs)* #[allow(unused_variables)] - #vis async fn #fn_name(#(#fn_args),*) #output_arrow #return_ty { + #vis async fn #fn_name #ty_generics (#(#fn_args),*) #output_arrow #return_ty #where_clause { use #server_fn_path::ServerFn; #restructure data.run_on_client().await @@ -628,20 +706,39 @@ pub fn server_macro_impl( quote! { vec![] } }; + // auto-registration with inventory only if there are no generics + let inventory = if cfg!(feature = "ssr") + && ty_generics.clone().into_token_stream().is_empty() + { + quote! { + #server_fn_path::inventory::submit! {{ + use #server_fn_path::{ServerFn, codec::Encoding}; + #server_fn_path::ServerFnTraitObj::new( + #path, + #input::METHOD, + |req| { + Box::pin(#wrapped_struct_name_turbofish::run_on_server(req)) + }, + #wrapped_struct_name_turbofish::middlewares + ) + }} + } + } else { + quote! {} + }; + Ok(quote::quote! { #args_docs #docs #[derive(Debug, #derives)] #addl_path - pub struct #struct_name { + pub struct #struct_name #ty_generics #struct_decl_where_clause { #(#fields),* } #from_impl - impl #server_fn_path::ServerFn for #wrapped_struct_name { - const PATH: &'static str = #path; - + impl #impl_generics #server_fn_path::ServerFn for #wrapped_struct_name #where_clause { type Client = #client; type ServerRequest = #req; type ServerResponse = #res; @@ -650,11 +747,18 @@ pub fn server_macro_impl( type OutputEncoding = #output; type Error = #error_ty; + fn url() -> &'static str { + #path + } + fn middlewares() -> Vec>> { #middlewares } - #run_body + #run_body_lint_supressor + fn run_body(self) -> impl std::future::Future + Send { + #run_body + } } #inventory @@ -1045,7 +1149,7 @@ impl Parse for ServerFnBody { let fn_token = input.parse()?; let ident = input.parse()?; - let generics: Generics = input.parse()?; + let mut generics = input.parse::()?; let content; let _paren_token = syn::parenthesized!(content in input); @@ -1054,6 +1158,7 @@ impl Parse for ServerFnBody { let output_arrow = input.parse()?; let return_ty = input.parse()?; + generics.where_clause = input.parse()?; let block = input.parse()?; @@ -1121,10 +1226,11 @@ impl ServerFnBody { block, .. } = &self; + let where_clause = generics.where_clause.as_ref(); quote! { #[doc(hidden)] #(#attrs)* - #vis #async_token #fn_token #ident #generics ( #inputs ) #output_arrow #return_ty + #vis #async_token #fn_token #ident #generics ( #inputs ) #output_arrow #return_ty #where_clause #block } }