diff --git a/.gitignore b/.gitignore index 57834d6..c2f4185 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ Cargo.lock .idea/ venv/ target/ +rust-toolchain.toml \ No newline at end of file diff --git a/example/derive_expression/expression_lib/Cargo.toml b/example/derive_expression/expression_lib/Cargo.toml index e3324b7..fb31c9a 100644 --- a/example/derive_expression/expression_lib/Cargo.toml +++ b/example/derive_expression/expression_lib/Cargo.toml @@ -13,3 +13,5 @@ polars = { workspace = true, features = ["fmt"], default-features = false } polars-plan = { workspace = true, default-features = false } pyo3 = { version = "0.19.0", features = ["extension-module"] } pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] } +serde = {version = "1", features = ["derive"] } +jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] } diff --git a/example/derive_expression/expression_lib/expression_lib/__init__.py b/example/derive_expression/expression_lib/expression_lib/__init__.py index 8d3d4ff..c16991a 100644 --- a/example/derive_expression/expression_lib/expression_lib/__init__.py +++ b/example/derive_expression/expression_lib/expression_lib/__init__.py @@ -23,7 +23,6 @@ def append_args( integer_arg: int, string_arg: str, boolean_arg: bool, - dict_arg: dict, ) -> pl.Expr: """ This example shows how arguments other than `Series` can be used. @@ -36,7 +35,6 @@ def append_args( "integer_arg": integer_arg, "string_arg": string_arg, "boolean_arg": boolean_arg, - "dict_arg": dict_arg, }, symbol="append_kwargs", is_elementwise=True, diff --git a/example/derive_expression/expression_lib/src/expressions.rs b/example/derive_expression/expression_lib/src/expressions.rs index fb5eb38..a142b4f 100644 --- a/example/derive_expression/expression_lib/src/expressions.rs +++ b/example/derive_expression/expression_lib/src/expressions.rs @@ -1,6 +1,7 @@ use polars::prelude::*; use polars_plan::dsl::FieldsMapper; -use pyo3_polars::derive::{polars_expr, Kwargs}; +use pyo3_polars::derive::{polars_expr, DefaultKwargs}; +use serde::Deserialize; use std::fmt::Write; fn pig_latin_str(value: &str, output: &mut String) { @@ -10,21 +11,21 @@ fn pig_latin_str(value: &str, output: &mut String) { } #[polars_expr(output_type=Utf8)] -fn pig_latinnify(inputs: &[Series], _kwargs: Option) -> PolarsResult { +fn pig_latinnify(inputs: &[Series], _kwargs: Option) -> PolarsResult { let ca = inputs[0].utf8()?; let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str); Ok(out.into_series()) } #[polars_expr(output_type=Float64)] -fn jaccard_similarity(inputs: &[Series], _kwargs: Option) -> PolarsResult { +fn jaccard_similarity(inputs: &[Series], _kwargs: Option) -> PolarsResult { let a = inputs[0].list()?; let b = inputs[1].list()?; crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series()) } #[polars_expr(output_type=Float64)] -fn hamming_distance(inputs: &[Series], _kwargs: Option) -> PolarsResult { +fn hamming_distance(inputs: &[Series], _kwargs: Option) -> PolarsResult { let a = inputs[0].utf8()?; let b = inputs[1].utf8()?; let out: UInt32Chunked = @@ -37,7 +38,7 @@ fn haversine_output(input_fields: &[Field]) -> PolarsResult { } #[polars_expr(type_func=haversine_output)] -fn haversine(inputs: &[Series], _kwargs: Option) -> PolarsResult { +fn haversine(inputs: &[Series], _kwargs: Option) -> PolarsResult { let out = match inputs[0].dtype() { DataType::Float32 => { let start_lat = inputs[0].f32().unwrap(); @@ -60,41 +61,21 @@ fn haversine(inputs: &[Series], _kwargs: Option) -> PolarsResult Ok(out) } -fn map_err(msg: &str) -> PolarsError { - polars_err!(ComputeError: "{msg}") +/// The `DefaultKwargs` isn't very ergonomic as it doesn't validate any schema. +/// Provide your own kwargs struct with the proper schema and accept that type +/// in your plugin expression. +#[derive(Deserialize)] +pub struct MyKwargs { + float_arg: f64, + integer_arg: i64, + string_arg: String, + boolean_arg: bool, } #[polars_expr(output_type=Utf8)] -fn append_kwargs(input: &[Series], kwargs: Option) -> PolarsResult { +fn append_kwargs(input: &[Series], kwargs: Option) -> PolarsResult { let input = &input[0]; - let kwargs = kwargs.ok_or_else(|| polars_err!(ComputeError: "expected kwargs"))?; - - let float_arg = kwargs - .get("float_arg") - .ok_or_else(|| map_err("expected 'float_arg'"))? - .as_f64() - .ok_or_else(|| map_err("expected float"))?; - let integer_arg = kwargs - .get("integer_arg") - .ok_or_else(|| map_err("expected 'integer_arg'"))? - .as_i64() - .ok_or_else(|| map_err("expected integer"))?; - let string_arg = kwargs - .get("string_arg") - .ok_or_else(|| map_err("expected 'string_arg'"))? - .as_str() - .ok_or_else(|| map_err("expected string"))?; - let boolean_arg = kwargs - .get("boolean_arg") - .ok_or_else(|| map_err("expected 'boolean_arg'"))? - .as_bool() - .ok_or_else(|| map_err("expected boolean"))?; - let dict_arg = kwargs - .get("dict_arg") - .ok_or_else(|| map_err("expected 'dict_arg'"))? - .as_object() - .ok_or_else(|| map_err("expected dict"))?; - + let kwargs = kwargs.unwrap(); let input = input.cast(&DataType::Utf8)?; let ca = input.utf8().unwrap(); @@ -102,8 +83,8 @@ fn append_kwargs(input: &[Series], kwargs: Option) -> PolarsResult proc_macro2::TokenStream { #[no_mangle] pub unsafe extern "C" fn #fn_name ( e: *mut polars_ffi::SeriesExport, - len: usize, - kwargs: *const std::os::raw::c_char, + input_len: usize, + kwargs_ptr: *const u8, + kwargs_len: usize, return_value: *mut polars_ffi::SeriesExport ) { - let inputs = polars_ffi::import_series_buffer(e, len).unwrap(); - let kwargs = std::ffi::CStr::from_ptr(kwargs).to_bytes(); + let inputs = polars_ffi::import_series_buffer(e, input_len).unwrap(); + + let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len); let kwargs = if kwargs.is_empty() { ::std::option::Option::None } else { - let value = pyo3_polars::derive::_parse_kwargs(kwargs); - ::std::option::Option::Some(value) + match pyo3_polars::derive::_parse_kwargs(kwargs) { + Ok(value) => Some(value), + Err(err) => { + pyo3_polars::derive::_update_last_error(err); + return; + } + } }; // define the function diff --git a/pyo3-polars/Cargo.toml b/pyo3-polars/Cargo.toml index d9d43ea..5bf4910 100644 --- a/pyo3-polars/Cargo.toml +++ b/pyo3-polars/Cargo.toml @@ -19,8 +19,10 @@ polars-plan = { workspace = true, optional = true } pyo3 = "0.19.0" pyo3-polars-derive = { version = "0.1.0", path = "../pyo3-polars-derive", optional = true } thiserror = "1" -serde_json = {version = "1.0", optional = true } +serde-pickle = {version = "1", optional = true } +serde = {version = "1", optional = true} + [features] lazy = ["polars/serde-lazy", "polars-plan", "polars-lazy/serde", "ciborium"] -derive = ["pyo3-polars-derive", "polars-plan", "polars-ffi", "serde_json"] +derive = ["pyo3-polars-derive", "polars-plan", "polars-ffi", "serde-pickle", "serde"] diff --git a/pyo3-polars/src/derive.rs b/pyo3-polars/src/derive.rs index 1f8703e..405bc28 100644 --- a/pyo3-polars/src/derive.rs +++ b/pyo3-polars/src/derive.rs @@ -1,23 +1,21 @@ use polars::prelude::PolarsError; +use polars_core::error::{to_compute_err, PolarsResult}; pub use pyo3_polars_derive::polars_expr; -pub use serde_json; -pub use serde_json::{Map, Value}; +use serde::Deserialize; use std::cell::RefCell; use std::ffi::CString; -pub type Kwargs = serde_json::Map; + +pub type DefaultKwargs = serde_pickle::Value; thread_local! { static LAST_ERROR: RefCell = RefCell::new(CString::default()); } -pub unsafe fn _parse_kwargs(kwargs: &[u8]) -> Kwargs { - let kwargs = std::str::from_utf8_unchecked(kwargs); - let value = serde_json::from_str(kwargs).unwrap(); - if let Value::Object(kwargs) = value { - return kwargs; - } else { - panic!("expected kwargs dictionary") - } +pub unsafe fn _parse_kwargs<'a, T>(kwargs: &'a [u8]) -> PolarsResult +where + T: Deserialize<'a>, +{ + serde_pickle::from_slice(kwargs, Default::default()).map_err(to_compute_err) } pub fn _update_last_error(err: PolarsError) { diff --git a/pyo3-polars/src/error.rs b/pyo3-polars/src/error.rs index 7dea7ba..9e38cf6 100644 --- a/pyo3-polars/src/error.rs +++ b/pyo3-polars/src/error.rs @@ -1,7 +1,6 @@ use std::fmt::{Debug, Formatter}; use polars::prelude::PolarsError; -use polars_core::error::ArrowError; use pyo3::create_exception; use pyo3::exceptions::{PyException, PyIOError, PyIndexError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; @@ -13,8 +12,6 @@ pub enum PyPolarsErr { Polars(#[from] PolarsError), #[error("{0}")] Other(String), - #[error(transparent)] - Arrow(#[from] ArrowError), } impl std::convert::From for PyErr { @@ -31,7 +28,6 @@ impl std::convert::From for PyErr { PolarsError::Io(err) => PyIOError::new_err(err.to_string()), PolarsError::OutOfBounds(err) => PyIndexError::new_err(err.to_string()), PolarsError::InvalidOperation(err) => PyValueError::new_err(err.to_string()), - PolarsError::ArrowError(err) => ArrowErrorException::new_err(format!("{:?}", err)), PolarsError::Duplicate(err) => DuplicateError::new_err(err.to_string()), PolarsError::ColumnNotFound(err) => ColumnNotFound::new_err(err.to_string()), PolarsError::SchemaFieldNotFound(err) => { @@ -44,7 +40,6 @@ impl std::convert::From for PyErr { StringCacheMismatchError::new_err(err.to_string()) } }, - Arrow(err) => ArrowErrorException::new_err(format!("{:?}", err)), _ => default(), } } @@ -56,7 +51,6 @@ impl Debug for PyPolarsErr { match self { Polars(err) => write!(f, "{:?}", err), Other(err) => write!(f, "BindingsError: {:?}", err), - Arrow(err) => write!(f, "{:?}", err), } } } @@ -66,7 +60,6 @@ create_exception!(exceptions, SchemaFieldNotFound, PyException); create_exception!(exceptions, StructFieldNotFound, PyException); create_exception!(exceptions, ComputeError, PyException); create_exception!(exceptions, NoDataError, PyException); -create_exception!(exceptions, ArrowErrorException, PyException); create_exception!(exceptions, ShapeError, PyException); create_exception!(exceptions, SchemaError, PyException); create_exception!(exceptions, DuplicateError, PyException); diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..4a5741c --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly-2023-10-12"