From fca88bb030bfc8107530a2a95fd4f9abfebd5b8a Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 17 Oct 2023 07:48:32 +0200 Subject: [PATCH 1/4] add logical type example --- Cargo.toml | 16 +++++++++++----- .../derive_expression/expression_lib/Cargo.toml | 2 +- .../expression_lib/expression_lib/__init__.py | 13 +++++++++++++ .../expression_lib/src/expressions.rs | 14 ++++++++++++++ .../derive_expression/expression_lib/src/lib.rs | 3 +++ example/derive_expression/run.py | 3 +++ 6 files changed, 45 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 53918fb..c68c0f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,14 @@ members = [ ] [workspace.dependencies] -polars = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } -polars-core = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } -polars-ffi = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } -polars-plan = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-feautres = false } -polars-lazy = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } +#polars = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } +#polars-core = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } +#polars-ffi = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } +#polars-plan = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-feautres = false } +#polars-lazy = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } + +polars = { path = "../polars/crates/polars", version = "0.33.2", default-features = false } +polars-core = { path = "../polars/crates/polars-core", version = "0.33.2", default-features = false } +polars-ffi = { path = "../polars/crates/polars-ffi", version = "0.33.2", default-features = false } +polars-plan = { path = "../polars/crates/polars-plan", version = "0.33.2", default-feautres = false } +polars-lazy = { path = "../polars/crates/polars-lazy", version = "0.33.2", default-features = false } diff --git a/example/derive_expression/expression_lib/Cargo.toml b/example/derive_expression/expression_lib/Cargo.toml index d1a4ac2..9c85b1f 100644 --- a/example/derive_expression/expression_lib/Cargo.toml +++ b/example/derive_expression/expression_lib/Cargo.toml @@ -10,7 +10,7 @@ crate-type = ["cdylib"] [dependencies] jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] } -polars = { workspace = true, features = ["fmt"], default-features = false } +polars = { workspace = true, features = ["fmt", "dtype-date"], default-features = false } polars-plan = { workspace = true, default-features = false } pyo3 = { version = "0.20.0", features = ["extension-module"] } pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] } diff --git a/example/derive_expression/expression_lib/expression_lib/__init__.py b/example/derive_expression/expression_lib/expression_lib/__init__.py index c16991a..9b3be74 100644 --- a/example/derive_expression/expression_lib/expression_lib/__init__.py +++ b/example/derive_expression/expression_lib/expression_lib/__init__.py @@ -76,3 +76,16 @@ def haversine( is_elementwise=True, cast_to_supertypes=True, ) + +@pl.api.register_expr_namespace("date_util") +class DateUtil: + def __init__(self, expr: pl.Expr): + self._expr = expr + + + def is_leap_year(self) -> pl.Expr: + return self._expr._register_plugin( + lib=lib, + symbol="is_leap_year", + is_elementwise=True, + ) diff --git a/example/derive_expression/expression_lib/src/expressions.rs b/example/derive_expression/expression_lib/src/expressions.rs index a142b4f..4854ffe 100644 --- a/example/derive_expression/expression_lib/src/expressions.rs +++ b/example/derive_expression/expression_lib/src/expressions.rs @@ -90,3 +90,17 @@ fn append_kwargs(input: &[Series], kwargs: Option) -> PolarsResult) -> PolarsResult { + let input = &input[0]; + let ca = input.date()?; + + let out: BooleanChunked = ca.as_date_iter().map(|opt_dt| { + opt_dt.map(|dt| { + dt.leap_year() + }) + }).collect_ca(ca.name()); + + Ok(out.into_series()) +} diff --git a/example/derive_expression/expression_lib/src/lib.rs b/example/derive_expression/expression_lib/src/lib.rs index beffb7b..d5c6766 100644 --- a/example/derive_expression/expression_lib/src/lib.rs +++ b/example/derive_expression/expression_lib/src/lib.rs @@ -1,6 +1,9 @@ mod distances; mod expressions; +#[cfg(target_os = "linux")] +use jemallocator::Jemalloc; + #[global_allocator] #[cfg(target_os = "linux")] static ALLOC: Jemalloc = Jemalloc; diff --git a/example/derive_expression/run.py b/example/derive_expression/run.py index 52d0192..92c094b 100644 --- a/example/derive_expression/run.py +++ b/example/derive_expression/run.py @@ -1,10 +1,12 @@ import polars as pl from expression_lib import Language, Distance +from datetime import date df = pl.DataFrame( { "names": ["Richard", "Alice", "Bob"], "moons": ["full", "half", "red"], + "dates": [date(2023, 1, 1), date(2024, 1, 1), date(2025, 1, 1)], "dist_a": [[12, 32, 1], [], [1, -2]], "dist_b": [[-12, 1], [43], [876, -45, 9]], "floats": [5.6, -1245.8, 242.224], @@ -18,6 +20,7 @@ hamming_dist=pl.col("names").dist.hamming_distance("pig_latin"), jaccard_sim=pl.col("dist_a").dist.jaccard_similarity("dist_b"), haversine=pl.col("floats").dist.haversine("floats", "floats", "floats", "floats"), + leap_year=pl.col("dates").date_util.is_leap_year(), appended_args=pl.col("names").language.append_args( float_arg=11.234, integer_arg=93, From 5096b3c256678d9fff760cf6f5f5bc5ea230d189 Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 17 Oct 2023 07:49:23 +0200 Subject: [PATCH 2/4] fmt --- Cargo.toml | 10 +++++----- .../expression_lib/src/expressions.rs | 9 ++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c68c0f5..e508e12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,11 +8,11 @@ members = [ ] [workspace.dependencies] -#polars = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } -#polars-core = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } -#polars-ffi = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } -#polars-plan = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-feautres = false } -#polars-lazy = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } +# polars = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } +# polars-core = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } +# polars-ffi = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } +# polars-plan = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-feautres = false } +# polars-lazy = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } polars = { path = "../polars/crates/polars", version = "0.33.2", default-features = false } polars-core = { path = "../polars/crates/polars-core", version = "0.33.2", default-features = false } diff --git a/example/derive_expression/expression_lib/src/expressions.rs b/example/derive_expression/expression_lib/src/expressions.rs index 4854ffe..ba03bf5 100644 --- a/example/derive_expression/expression_lib/src/expressions.rs +++ b/example/derive_expression/expression_lib/src/expressions.rs @@ -96,11 +96,10 @@ fn is_leap_year(input: &[Series], _kwargs: Option) -> PolarsResul let input = &input[0]; let ca = input.date()?; - let out: BooleanChunked = ca.as_date_iter().map(|opt_dt| { - opt_dt.map(|dt| { - dt.leap_year() - }) - }).collect_ca(ca.name()); + let out: BooleanChunked = ca + .as_date_iter() + .map(|opt_dt| opt_dt.map(|dt| dt.leap_year())) + .collect_ca(ca.name()); Ok(out.into_series()) } From 551c2b559aa46bdf777c63cf6c99ae6e8e584af2 Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 17 Oct 2023 08:30:52 +0200 Subject: [PATCH 3/4] automatically pass kwargs when needed --- Cargo.toml | 20 ++-- .../expression_lib/src/expressions.rs | 18 ++-- pyo3-polars-derive/src/lib.rs | 101 +++++++++++++----- 3 files changed, 92 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e508e12..2110bd6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,14 +8,14 @@ members = [ ] [workspace.dependencies] -# polars = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } -# polars-core = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } -# polars-ffi = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } -# polars-plan = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-feautres = false } -# polars-lazy = { git = "https://github.com/pola-rs/polars", rev = "5d48cc800bc9c71fe6d4ff97b96d7fed4601793b", version = "0.33.2", default-features = false } + polars = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-features = false } + polars-core = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-features = false } + polars-ffi = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-features = false } + polars-plan = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-feautres = false } + polars-lazy = { git = "https://github.com/pola-rs/polars", rev = "d00a43203b3ade009a5f858f4c698b6a50f5b1e6", version = "0.33.2", default-features = false } -polars = { path = "../polars/crates/polars", version = "0.33.2", default-features = false } -polars-core = { path = "../polars/crates/polars-core", version = "0.33.2", default-features = false } -polars-ffi = { path = "../polars/crates/polars-ffi", version = "0.33.2", default-features = false } -polars-plan = { path = "../polars/crates/polars-plan", version = "0.33.2", default-feautres = false } -polars-lazy = { path = "../polars/crates/polars-lazy", version = "0.33.2", default-features = false } +#polars = { path = "../polars/crates/polars", version = "0.33.2", default-features = false } +#polars-core = { path = "../polars/crates/polars-core", version = "0.33.2", default-features = false } +#polars-ffi = { path = "../polars/crates/polars-ffi", version = "0.33.2", default-features = false } +#polars-plan = { path = "../polars/crates/polars-plan", version = "0.33.2", default-feautres = false } +#polars-lazy = { path = "../polars/crates/polars-lazy", version = "0.33.2", default-features = false } diff --git a/example/derive_expression/expression_lib/src/expressions.rs b/example/derive_expression/expression_lib/src/expressions.rs index ba03bf5..e0ff2ac 100644 --- a/example/derive_expression/expression_lib/src/expressions.rs +++ b/example/derive_expression/expression_lib/src/expressions.rs @@ -1,6 +1,6 @@ use polars::prelude::*; use polars_plan::dsl::FieldsMapper; -use pyo3_polars::derive::{polars_expr, DefaultKwargs}; +use pyo3_polars::derive::polars_expr; use serde::Deserialize; use std::fmt::Write; @@ -11,21 +11,21 @@ fn pig_latin_str(value: &str, output: &mut String) { } #[polars_expr(output_type=Utf8)] -fn pig_latinnify(inputs: &[Series], _kwargs: Option) -> PolarsResult { +fn pig_latinnify(inputs: &[Series]) -> PolarsResult { let ca = inputs[0].utf8()?; let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str); Ok(out.into_series()) } #[polars_expr(output_type=Float64)] -fn jaccard_similarity(inputs: &[Series], _kwargs: Option) -> PolarsResult { +fn jaccard_similarity(inputs: &[Series]) -> PolarsResult { let a = inputs[0].list()?; let b = inputs[1].list()?; crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series()) } #[polars_expr(output_type=Float64)] -fn hamming_distance(inputs: &[Series], _kwargs: Option) -> PolarsResult { +fn hamming_distance(inputs: &[Series]) -> PolarsResult { let a = inputs[0].utf8()?; let b = inputs[1].utf8()?; let out: UInt32Chunked = @@ -38,7 +38,7 @@ fn haversine_output(input_fields: &[Field]) -> PolarsResult { } #[polars_expr(type_func=haversine_output)] -fn haversine(inputs: &[Series], _kwargs: Option) -> PolarsResult { +fn haversine(inputs: &[Series]) -> PolarsResult { let out = match inputs[0].dtype() { DataType::Float32 => { let start_lat = inputs[0].f32().unwrap(); @@ -72,10 +72,12 @@ pub struct MyKwargs { boolean_arg: bool, } +/// If you want to accept `kwargs`. You define a `kwargs` argument +/// on the second position in you plugin. You can provide any custom struct that is deserializable +/// with the pickle protocol (on the rust side). #[polars_expr(output_type=Utf8)] -fn append_kwargs(input: &[Series], kwargs: Option) -> PolarsResult { +fn append_kwargs(input: &[Series], kwargs: MyKwargs) -> PolarsResult { let input = &input[0]; - let kwargs = kwargs.unwrap(); let input = input.cast(&DataType::Utf8)?; let ca = input.utf8().unwrap(); @@ -92,7 +94,7 @@ fn append_kwargs(input: &[Series], kwargs: Option) -> PolarsResult) -> PolarsResult { +fn is_leap_year(input: &[Series]) -> PolarsResult { let input = &input[0]; let ca = input.date()?; diff --git a/pyo3-polars-derive/src/lib.rs b/pyo3-polars-derive/src/lib.rs index aa58f7f..ba2ea2c 100644 --- a/pyo3-polars-derive/src/lib.rs +++ b/pyo3-polars-derive/src/lib.rs @@ -4,7 +4,7 @@ mod keywords; use proc_macro::TokenStream; use quote::quote; use std::sync::atomic::{AtomicBool, Ordering}; -use syn::parse_macro_input; +use syn::{parse_macro_input, FnArg}; static INIT: AtomicBool = AtomicBool::new(false); @@ -21,10 +21,79 @@ fn insert_error_function() -> proc_macro2::TokenStream { } } +fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream { + quote!( + + let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len); + + let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) { + Ok(value) => value, + Err(err) => { + pyo3_polars::derive::_update_last_error(err); + return; + } + }; + + // define the function + #ast + + // call the function + let result: PolarsResult = #fn_name(&inputs, kwargs); + + ) +} + +fn quote_call_no_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream { + quote!( + // define the function + #ast + // call the function + let result: PolarsResult = #fn_name(&inputs); + ) +} + +fn quote_process_results() -> proc_macro2::TokenStream { + quote!(match result { + Ok(out) => { + // Update return value. + *return_value = polars_ffi::export_series(&out); + } + Err(err) => { + // Set latest error, but leave return value in empty state. + pyo3_polars::derive::_update_last_error(err); + } + }) +} + fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream { + // count how often the user define a kwargs argument. + let n_kwargs = ast + .sig + .inputs + .iter() + .filter(|fn_arg| { + if let FnArg::Typed(pat) = fn_arg { + if let syn::Pat::Ident(pat) = pat.pat.as_ref() { + pat.ident.to_string() == "kwargs" + } else { + false + } + } else { + true + } + }) + .count(); + let fn_name = &ast.sig.ident; let error_msg_fn = insert_error_function(); + let quote_call = match n_kwargs { + 0 => quote_call_no_kwargs(&ast, fn_name), + 1 => quote_call_kwargs(&ast, fn_name), + _ => panic!("expected 0 or 1 kwargs, got {}", n_kwargs), + }; + let quote_process_result = quote_process_results(); + quote!( use pyo3_polars::export::*; @@ -41,35 +110,9 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream { ) { let inputs = polars_ffi::import_series_buffer(e, input_len).unwrap(); - let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len); - - let kwargs = if kwargs.is_empty() { - ::std::option::Option::None - } else { - match pyo3_polars::derive::_parse_kwargs(kwargs) { - Ok(value) => Some(value), - Err(err) => { - pyo3_polars::derive::_update_last_error(err); - return; - } - } - }; - - // define the function - #ast + #quote_call - // call the function - let result: PolarsResult = #fn_name(&inputs, kwargs); - match result { - Ok(out) => { - // Update return value. - *return_value = polars_ffi::export_series(&out); - }, - Err(err) => { - // Set latest error, but leave return value in empty state. - pyo3_polars::derive::_update_last_error(err); - } - } + #quote_process_result } ) } From 075b5309bd90a3648e69d8bcb987354466574d14 Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 17 Oct 2023 08:42:30 +0200 Subject: [PATCH 4/4] unreachable --- pyo3-polars-derive/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyo3-polars-derive/src/lib.rs b/pyo3-polars-derive/src/lib.rs index ba2ea2c..3f4eea2 100644 --- a/pyo3-polars-derive/src/lib.rs +++ b/pyo3-polars-derive/src/lib.rs @@ -90,7 +90,7 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream { let quote_call = match n_kwargs { 0 => quote_call_no_kwargs(&ast, fn_name), 1 => quote_call_kwargs(&ast, fn_name), - _ => panic!("expected 0 or 1 kwargs, got {}", n_kwargs), + _ => unreachable!(), // arguments are unique }; let quote_process_result = quote_process_results();