From 599c2bb2d4ba3a91ff8e01cc27c6860488ef7f7e Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Tue, 26 Nov 2024 20:23:01 +0100 Subject: [PATCH] Introduce a `#[declare_sql_function]` attribute macro This commit introduces a new `#[declare_sql_function]` attribute macro that can be applied to `extern "SQL"` blocks. This is essentially the same as the existing `define_sql_function!` function like macro in terms of functionality. I see the following advantages of using an attribute macro + an `extern "SQL"` block instead: * This is closer to rust syntax, so rustfmt will understand that and work correctly inside these blocks * This allows to put several functions into the same block * Maybe in the future this also allows to apply attributes to the whole block instead of to each item The downside of this change is that we then have three variants to declare sql functions: * `sql_function!()` (deprectated) * `define_sql_function!()` (introduced in 2.2, we might want to deprecate that as well?) * The new attribute macro --- .../expression/functions/aggregate_folding.rs | 7 +- diesel/src/expression/functions/mod.rs | 2 + diesel_derives/src/lib.rs | 363 ++++++++++++++++++ diesel_derives/src/sql_function.rs | 27 ++ 4 files changed, 395 insertions(+), 4 deletions(-) diff --git a/diesel/src/expression/functions/aggregate_folding.rs b/diesel/src/expression/functions/aggregate_folding.rs index 25fcc8da051f..081cbfed8605 100644 --- a/diesel/src/expression/functions/aggregate_folding.rs +++ b/diesel/src/expression/functions/aggregate_folding.rs @@ -1,7 +1,8 @@ -use crate::expression::functions::define_sql_function; +use crate::expression::functions::declare_sql_function; use crate::sql_types::Foldable; -define_sql_function! { +#[declare_sql_function] +extern "SQL" { /// Represents a SQL `SUM` function. This function can only take types which are /// Foldable. /// @@ -19,9 +20,7 @@ define_sql_function! { /// ``` #[aggregate] fn sum(expr: ST) -> ST::Sum; -} -define_sql_function! { /// Represents a SQL `AVG` function. This function can only take types which are /// Foldable. /// diff --git a/diesel/src/expression/functions/mod.rs b/diesel/src/expression/functions/mod.rs index db8f79e7a730..34c41d9d646f 100644 --- a/diesel/src/expression/functions/mod.rs +++ b/diesel/src/expression/functions/mod.rs @@ -1,5 +1,7 @@ //! Helper macros to define custom sql functions +#[doc(inline)] +pub use diesel_derives::declare_sql_function; #[doc(inline)] pub use diesel_derives::define_sql_function; diff --git a/diesel_derives/src/lib.rs b/diesel_derives/src/lib.rs index 3b53aca5527b..17c5900f8d7e 100644 --- a/diesel_derives/src/lib.rs +++ b/diesel_derives/src/lib.rs @@ -23,6 +23,7 @@ extern crate quote; extern crate syn; use proc_macro::TokenStream; +use sql_function::ExternSqlBlock; use syn::{parse_macro_input, parse_quote}; mod attrs; @@ -1912,3 +1913,365 @@ pub fn auto_type( const AUTO_TYPE_DEFAULT_METHOD_TYPE_CASE: dsl_auto_type::Case = dsl_auto_type::Case::UpperCamel; const AUTO_TYPE_DEFAULT_FUNCTION_TYPE_CASE: dsl_auto_type::Case = dsl_auto_type::Case::DoNotChange; + +/// Declare a sql function for use in your code. +/// +/// Diesel only provides support for a very small number of SQL functions. +/// This macro enables you to add additional functions from the SQL standard, +/// as well as any custom functions your application might have. +/// +/// The syntax for this attribute macro is designed to be applied to `extern "SQL"` blocks +/// with function definitions. These function typically use types +/// from [`diesel::sql_types`](../diesel/sql_types/index.html) as arguments and return types. +/// You can use such definitions to declare bindings to unsupported SQL functions. +/// +/// For each function in this `extern` block the macro will generate two items. +/// A function with the name that you've given, and a module with a helper type +/// representing the return type of your function. For example, this invocation: +/// +/// ```ignore +/// #[declare_sql_function] +/// extern "SQL" { +/// fn lower(x: Text) -> Text +/// } +/// ``` +/// +/// will generate this code: +/// +/// ```ignore +/// pub fn lower(x: X) -> lower { +/// ... +/// } +/// +/// pub type lower = ...; +/// ``` +/// +/// Most attributes given to this macro will be put on the generated function +/// (including doc comments). +/// +/// # Adding Doc Comments +/// +/// ```no_run +/// # extern crate diesel; +/// # use diesel::*; +/// # use diesel::expression::functions::declare_sql_function; +/// # +/// # table! { crates { id -> Integer, name -> VarChar, } } +/// # +/// use diesel::sql_types::Text; +/// +/// #[declare_sql_function] +/// extern "SQL" { +/// /// Represents the `canon_crate_name` SQL function, created in +/// /// migration .... +/// fn canon_crate_name(a: Text) -> Text; +/// } +/// +/// # fn main() { +/// # use self::crates::dsl::*; +/// let target_name = "diesel"; +/// crates.filter(canon_crate_name(name).eq(canon_crate_name(target_name))); +/// // This will generate the following SQL +/// // SELECT * FROM crates WHERE canon_crate_name(crates.name) = canon_crate_name($1) +/// # } +/// ``` +/// +/// # Special Attributes +/// +/// There are a handful of special attributes that Diesel will recognize. They +/// are: +/// +/// - `#[aggregate]` +/// - Indicates that this is an aggregate function, and that `NonAggregate` +/// shouldn't be implemented. +/// - `#[sql_name = "name"]` +/// - The SQL to be generated is different from the Rust name of the function. +/// This can be used to represent functions which can take many argument +/// types, or to capitalize function names. +/// +/// Functions can also be generic. Take the definition of `sum`, for example: +/// +/// ```no_run +/// # extern crate diesel; +/// # use diesel::*; +/// # use diesel::expression::functions::declare_sql_function; +/// # +/// # table! { crates { id -> Integer, name -> VarChar, } } +/// # +/// use diesel::sql_types::Foldable; +/// +/// #[declare_sql_function] +/// extern "SQL" { +/// #[aggregate] +/// #[sql_name = "SUM"] +/// fn sum(expr: ST) -> ST::Sum; +/// } +/// +/// # fn main() { +/// # use self::crates::dsl::*; +/// crates.select(sum(id)); +/// # } +/// ``` +/// +/// # SQL Functions without Arguments +/// +/// A common example is ordering a query using the `RANDOM()` sql function, +/// which can be implemented using `define_sql_function!` like this: +/// +/// ```rust +/// # extern crate diesel; +/// # use diesel::*; +/// # use diesel::expression::functions::declare_sql_function; +/// # +/// # table! { crates { id -> Integer, name -> VarChar, } } +/// # +/// #[declare_sql_function] +/// extern "SQL" { +/// fn random() -> Text; +/// } +/// +/// # fn main() { +/// # use self::crates::dsl::*; +/// crates.order(random()); +/// # } +/// ``` +/// +/// # Use with SQLite +/// +/// On most backends, the implementation of the function is defined in a +/// migration using `CREATE FUNCTION`. On SQLite, the function is implemented in +/// Rust instead. You must call `register_impl` or +/// `register_nondeterministic_impl` (in the generated function's `_internals` +/// module) with every connection before you can use the function. +/// +/// These functions will only be generated if the `sqlite` feature is enabled, +/// and the function is not generic. +/// SQLite doesn't support generic functions and variadic functions. +/// +/// ```rust +/// # extern crate diesel; +/// # use diesel::*; +/// # use diesel::expression::functions::declare_sql_function; +/// # +/// # #[cfg(feature = "sqlite")] +/// # fn main() { +/// # run_test().unwrap(); +/// # } +/// # +/// # #[cfg(not(feature = "sqlite"))] +/// # fn main() { +/// # } +/// # +/// use diesel::sql_types::{Integer, Double}; +/// +/// #[declare_sql_function] +/// extern "SQL" { +/// fn add_mul(x: Integer, y: Integer, z: Double) -> Double; +/// } +/// +/// # #[cfg(feature = "sqlite")] +/// # fn run_test() -> Result<(), Box> { +/// let connection = &mut SqliteConnection::establish(":memory:")?; +/// +/// add_mul_utils::register_impl(connection, |x: i32, y: i32, z: f64| { +/// (x + y) as f64 * z +/// })?; +/// +/// let result = select(add_mul(1, 2, 1.5)) +/// .get_result::(connection)?; +/// assert_eq!(4.5, result); +/// # Ok(()) +/// # } +/// ``` +/// +/// ## Panics +/// +/// If an implementation of the custom function panics and unwinding is enabled, the panic is +/// caught and the function returns to libsqlite with an error. It can't propagate the panics due +/// to the FFI boundary. +/// +/// This is the same for [custom aggregate functions](#custom-aggregate-functions). +/// +/// ## Custom Aggregate Functions +/// +/// Custom aggregate functions can be created in SQLite by adding an `#[aggregate]` +/// attribute inside `define_sql_function`. `register_impl` (in the generated function's `_utils` +/// module) needs to be called with a type implementing the +/// [SqliteAggregateFunction](../diesel/sqlite/trait.SqliteAggregateFunction.html) +/// trait as a type parameter as shown in the examples below. +/// +/// ```rust +/// # extern crate diesel; +/// # use diesel::*; +/// # use diesel::expression::functions::declare_sql_function; +/// # +/// # #[cfg(feature = "sqlite")] +/// # fn main() { +/// # run().unwrap(); +/// # } +/// # +/// # #[cfg(not(feature = "sqlite"))] +/// # fn main() { +/// # } +/// use diesel::sql_types::Integer; +/// # #[cfg(feature = "sqlite")] +/// use diesel::sqlite::SqliteAggregateFunction; +/// +/// #[declare_sql_function] +/// extern "SQL" { +/// #[aggregate] +/// fn my_sum(x: Integer) -> Integer; +/// } +/// +/// #[derive(Default)] +/// struct MySum { sum: i32 } +/// +/// # #[cfg(feature = "sqlite")] +/// impl SqliteAggregateFunction for MySum { +/// type Output = i32; +/// +/// fn step(&mut self, expr: i32) { +/// self.sum += expr; +/// } +/// +/// fn finalize(aggregator: Option) -> Self::Output { +/// aggregator.map(|a| a.sum).unwrap_or_default() +/// } +/// } +/// # table! { +/// # players { +/// # id -> Integer, +/// # score -> Integer, +/// # } +/// # } +/// +/// # #[cfg(feature = "sqlite")] +/// fn run() -> Result<(), Box> { +/// # use self::players::dsl::*; +/// let connection = &mut SqliteConnection::establish(":memory:")?; +/// # diesel::sql_query("create table players (id integer primary key autoincrement, score integer)") +/// # .execute(connection) +/// # .unwrap(); +/// # diesel::sql_query("insert into players (score) values (10), (20), (30)") +/// # .execute(connection) +/// # .unwrap(); +/// +/// my_sum_utils::register_impl::(connection)?; +/// +/// let total_score = players.select(my_sum(score)) +/// .get_result::(connection)?; +/// +/// println!("The total score of all the players is: {}", total_score); +/// +/// # assert_eq!(60, total_score); +/// Ok(()) +/// } +/// ``` +/// +/// With multiple function arguments, the arguments are passed as a tuple to `SqliteAggregateFunction` +/// +/// ```rust +/// # extern crate diesel; +/// # use diesel::*; +/// # use diesel::expression::functions::declare_sql_function; +/// # +/// # #[cfg(feature = "sqlite")] +/// # fn main() { +/// # run().unwrap(); +/// # } +/// # +/// # #[cfg(not(feature = "sqlite"))] +/// # fn main() { +/// # } +/// use diesel::sql_types::{Float, Nullable}; +/// # #[cfg(feature = "sqlite")] +/// use diesel::sqlite::SqliteAggregateFunction; +/// +/// #[declare_sql_function] +/// extern "SQL" { +/// #[aggregate] +/// fn range_max(x0: Float, x1: Float) -> Nullable; +/// } +/// +/// #[derive(Default)] +/// struct RangeMax { max_value: Option } +/// +/// # #[cfg(feature = "sqlite")] +/// impl SqliteAggregateFunction<(T, T)> for RangeMax { +/// type Output = Option; +/// +/// fn step(&mut self, (x0, x1): (T, T)) { +/// # let max = if x0 >= x1 { +/// # x0 +/// # } else { +/// # x1 +/// # }; +/// # +/// # self.max_value = match self.max_value { +/// # Some(current_max_value) if max > current_max_value => Some(max), +/// # None => Some(max), +/// # _ => self.max_value, +/// # }; +/// // Compare self.max_value to x0 and x1 +/// } +/// +/// fn finalize(aggregator: Option) -> Self::Output { +/// aggregator?.max_value +/// } +/// } +/// # table! { +/// # student_avgs { +/// # id -> Integer, +/// # s1_avg -> Float, +/// # s2_avg -> Float, +/// # } +/// # } +/// +/// # #[cfg(feature = "sqlite")] +/// fn run() -> Result<(), Box> { +/// # use self::student_avgs::dsl::*; +/// let connection = &mut SqliteConnection::establish(":memory:")?; +/// # diesel::sql_query("create table student_avgs (id integer primary key autoincrement, s1_avg float, s2_avg float)") +/// # .execute(connection) +/// # .unwrap(); +/// # diesel::sql_query("insert into student_avgs (s1_avg, s2_avg) values (85.5, 90), (79.8, 80.1)") +/// # .execute(connection) +/// # .unwrap(); +/// +/// range_max_utils::register_impl::, _, _>(connection)?; +/// +/// let result = student_avgs.select(range_max(s1_avg, s2_avg)) +/// .get_result::>(connection)?; +/// +/// if let Some(max_semester_avg) = result { +/// println!("The largest semester average is: {}", max_semester_avg); +/// } +/// +/// # assert_eq!(Some(90f32), result); +/// Ok(()) +/// } +/// ``` +#[proc_macro_attribute] +pub fn declare_sql_function( + _attr: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let input = proc_macro2::TokenStream::from(input); + let result = syn::parse2::(input.clone()).map(|res| { + let expanded = res + .function_decls + .into_iter() + .map(|decl| sql_function::expand(decl, false)); + quote::quote! { + #(#expanded)* + } + }); + match result { + Ok(token_stream) => token_stream.into(), + Err(e) => { + let mut output = input; + output.extend(e.into_compile_error()); + output.into() + } + } +} diff --git a/diesel_derives/src/sql_function.rs b/diesel_derives/src/sql_function.rs index c35a3fe95930..5ed628522c80 100644 --- a/diesel_derives/src/sql_function.rs +++ b/diesel_derives/src/sql_function.rs @@ -3,6 +3,7 @@ use quote::quote; use quote::ToTokens; use syn::parse::{Parse, ParseStream, Result}; use syn::punctuated::Punctuated; +use syn::spanned::Spanned; use syn::{ parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, Meta, MetaNameValue, PathArguments, Token, Type, @@ -437,6 +438,32 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool } } +pub(crate) struct ExternSqlBlock { + pub(crate) function_decls: Vec, +} + +impl Parse for ExternSqlBlock { + fn parse(input: ParseStream) -> Result { + let block = syn::ItemForeignMod::parse(input)?; + if block.abi.name.as_ref().map(|n| n.value()) != Some("SQL".into()) { + return Err(syn::Error::new(block.abi.span(), "expect `SQL` as ABI")); + } + if block.unsafety.is_some() { + return Err(syn::Error::new( + block.unsafety.unwrap().span(), + "expect `SQL` function blocks to be safe", + )); + } + let function_decls = block + .items + .into_iter() + .map(|i| syn::parse2(quote! { #i })) + .collect::>>()?; + + Ok(ExternSqlBlock { function_decls }) + } +} + pub(crate) struct SqlFunctionDecl { attributes: Vec, fn_token: Token![fn],