Skip to content

Commit

Permalink
pickle kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 16, 2023
1 parent 0148c07 commit 8ab4f14
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 69 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ Cargo.lock
.idea/
venv/
target/
rust-toolchain.toml
2 changes: 2 additions & 0 deletions example/derive_expression/expression_lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ polars = { workspace = true, features = ["fmt"], default-features = false }
polars-plan = { workspace = true, default-features = false }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] }
serde = {version = "1", features = ["derive"] }
jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] }
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def append_args(
integer_arg: int,
string_arg: str,
boolean_arg: bool,
dict_arg: dict,
) -> pl.Expr:
"""
This example shows how arguments other than `Series` can be used.
Expand All @@ -36,7 +35,6 @@ def append_args(
"integer_arg": integer_arg,
"string_arg": string_arg,
"boolean_arg": boolean_arg,
"dict_arg": dict_arg,
},
symbol="append_kwargs",
is_elementwise=True,
Expand Down
57 changes: 19 additions & 38 deletions example/derive_expression/expression_lib/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars::prelude::*;
use polars_plan::dsl::FieldsMapper;
use pyo3_polars::derive::{polars_expr, Kwargs};
use pyo3_polars::derive::{polars_expr, DefaultKwargs};
use serde::Deserialize;
use std::fmt::Write;

fn pig_latin_str(value: &str, output: &mut String) {
Expand All @@ -10,21 +11,21 @@ fn pig_latin_str(value: &str, output: &mut String) {
}

#[polars_expr(output_type=Utf8)]
fn pig_latinnify(inputs: &[Series], _kwargs: Option<Kwargs>) -> PolarsResult<Series> {
fn pig_latinnify(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
let ca = inputs[0].utf8()?;
let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str);
Ok(out.into_series())
}

#[polars_expr(output_type=Float64)]
fn jaccard_similarity(inputs: &[Series], _kwargs: Option<Kwargs>) -> PolarsResult<Series> {
fn jaccard_similarity(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
let a = inputs[0].list()?;
let b = inputs[1].list()?;
crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series())
}

#[polars_expr(output_type=Float64)]
fn hamming_distance(inputs: &[Series], _kwargs: Option<Kwargs>) -> PolarsResult<Series> {
fn hamming_distance(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
let a = inputs[0].utf8()?;
let b = inputs[1].utf8()?;
let out: UInt32Chunked =
Expand All @@ -37,7 +38,7 @@ fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> {
}

#[polars_expr(type_func=haversine_output)]
fn haversine(inputs: &[Series], _kwargs: Option<Kwargs>) -> PolarsResult<Series> {
fn haversine(inputs: &[Series], _kwargs: Option<DefaultKwargs>) -> PolarsResult<Series> {
let out = match inputs[0].dtype() {
DataType::Float32 => {
let start_lat = inputs[0].f32().unwrap();
Expand All @@ -60,50 +61,30 @@ fn haversine(inputs: &[Series], _kwargs: Option<Kwargs>) -> PolarsResult<Series>
Ok(out)
}

fn map_err(msg: &str) -> PolarsError {
polars_err!(ComputeError: "{msg}")
/// The `DefaultKwargs` isn't very ergonomic as it doesn't validate any schema.
/// Provide your own kwargs struct with the proper schema and accept that type
/// in your plugin expression.
#[derive(Deserialize)]
pub struct MyKwargs {
float_arg: f64,
integer_arg: i64,
string_arg: String,
boolean_arg: bool,
}

#[polars_expr(output_type=Utf8)]
fn append_kwargs(input: &[Series], kwargs: Option<Kwargs>) -> PolarsResult<Series> {
fn append_kwargs(input: &[Series], kwargs: Option<MyKwargs>) -> PolarsResult<Series> {
let input = &input[0];
let kwargs = kwargs.ok_or_else(|| polars_err!(ComputeError: "expected kwargs"))?;

let float_arg = kwargs
.get("float_arg")
.ok_or_else(|| map_err("expected 'float_arg'"))?
.as_f64()
.ok_or_else(|| map_err("expected float"))?;
let integer_arg = kwargs
.get("integer_arg")
.ok_or_else(|| map_err("expected 'integer_arg'"))?
.as_i64()
.ok_or_else(|| map_err("expected integer"))?;
let string_arg = kwargs
.get("string_arg")
.ok_or_else(|| map_err("expected 'string_arg'"))?
.as_str()
.ok_or_else(|| map_err("expected string"))?;
let boolean_arg = kwargs
.get("boolean_arg")
.ok_or_else(|| map_err("expected 'boolean_arg'"))?
.as_bool()
.ok_or_else(|| map_err("expected boolean"))?;
let dict_arg = kwargs
.get("dict_arg")
.ok_or_else(|| map_err("expected 'dict_arg'"))?
.as_object()
.ok_or_else(|| map_err("expected dict"))?;

let kwargs = kwargs.unwrap();
let input = input.cast(&DataType::Utf8)?;
let ca = input.utf8().unwrap();

Ok(ca
.apply_to_buffer(|val, buf| {
write!(
buf,
"{}-{}-{}-{}-{}-{:?}",
val, float_arg, integer_arg, string_arg, boolean_arg, dict_arg
"{}-{}-{}-{}-{}",
val, kwargs.float_arg, kwargs.integer_arg, kwargs.string_arg, kwargs.boolean_arg
)
.unwrap()
})
Expand Down
4 changes: 4 additions & 0 deletions example/derive_expression/expression_lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
mod distances;
mod expressions;

#[global_allocator]
#[cfg(target_os = "linux")]
static ALLOC: Jemalloc = Jemalloc;
4 changes: 1 addition & 3 deletions example/derive_expression/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,20 @@
integer_arg=93,
boolean_arg=False,
string_arg="example",
dict_arg={"foo": "bar"}
)
)

print(out)


# Tests we can return errors from FFI.
# Tests we can return errors from FFI by passing wrong types.
try:
out.with_columns(
appended_args=pl.col("names").language.append_args(
float_arg=True,
integer_arg=True,
boolean_arg=True,
string_arg="example",
dict_arg={"foo": "bar"}
))
except pl.ComputeError as e:
assert "the plugin failed with message" in str(e)
19 changes: 13 additions & 6 deletions pyo3-polars-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,25 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
#[no_mangle]
pub unsafe extern "C" fn #fn_name (
e: *mut polars_ffi::SeriesExport,
len: usize,
kwargs: *const std::os::raw::c_char,
input_len: usize,
kwargs_ptr: *const u8,
kwargs_len: usize,
return_value: *mut polars_ffi::SeriesExport
) {
let inputs = polars_ffi::import_series_buffer(e, len).unwrap();
let kwargs = std::ffi::CStr::from_ptr(kwargs).to_bytes();
let inputs = polars_ffi::import_series_buffer(e, input_len).unwrap();

let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);

let kwargs = if kwargs.is_empty() {
::std::option::Option::None
} else {
let value = pyo3_polars::derive::_parse_kwargs(kwargs);
::std::option::Option::Some(value)
match pyo3_polars::derive::_parse_kwargs(kwargs) {
Ok(value) => Some(value),
Err(err) => {
pyo3_polars::derive::_update_last_error(err);
return;
}
}
};

// define the function
Expand Down
6 changes: 4 additions & 2 deletions pyo3-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ polars-plan = { workspace = true, optional = true }
pyo3 = "0.19.0"
pyo3-polars-derive = { version = "0.1.0", path = "../pyo3-polars-derive", optional = true }
thiserror = "1"
serde_json = {version = "1.0", optional = true }
serde-pickle = {version = "1", optional = true }
serde = {version = "1", optional = true}


[features]
lazy = ["polars/serde-lazy", "polars-plan", "polars-lazy/serde", "ciborium"]
derive = ["pyo3-polars-derive", "polars-plan", "polars-ffi", "serde_json"]
derive = ["pyo3-polars-derive", "polars-plan", "polars-ffi", "serde-pickle", "serde"]
20 changes: 9 additions & 11 deletions pyo3-polars/src/derive.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
use polars::prelude::PolarsError;
use polars_core::error::{to_compute_err, PolarsResult};
pub use pyo3_polars_derive::polars_expr;
pub use serde_json;
pub use serde_json::{Map, Value};
use serde::Deserialize;
use std::cell::RefCell;
use std::ffi::CString;
pub type Kwargs = serde_json::Map<String, Value>;

pub type DefaultKwargs = serde_pickle::Value;

thread_local! {
static LAST_ERROR: RefCell<CString> = RefCell::new(CString::default());
}

pub unsafe fn _parse_kwargs(kwargs: &[u8]) -> Kwargs {
let kwargs = std::str::from_utf8_unchecked(kwargs);
let value = serde_json::from_str(kwargs).unwrap();
if let Value::Object(kwargs) = value {
return kwargs;
} else {
panic!("expected kwargs dictionary")
}
pub unsafe fn _parse_kwargs<'a, T>(kwargs: &'a [u8]) -> PolarsResult<T>
where
T: Deserialize<'a>,
{
serde_pickle::from_slice(kwargs, Default::default()).map_err(to_compute_err)
}

pub fn _update_last_error(err: PolarsError) {
Expand Down
7 changes: 0 additions & 7 deletions pyo3-polars/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::fmt::{Debug, Formatter};

use polars::prelude::PolarsError;
use polars_core::error::ArrowError;
use pyo3::create_exception;
use pyo3::exceptions::{PyException, PyIOError, PyIndexError, PyRuntimeError, PyValueError};
use pyo3::prelude::*;
Expand All @@ -13,8 +12,6 @@ pub enum PyPolarsErr {
Polars(#[from] PolarsError),
#[error("{0}")]
Other(String),
#[error(transparent)]
Arrow(#[from] ArrowError),
}

impl std::convert::From<PyPolarsErr> for PyErr {
Expand All @@ -31,7 +28,6 @@ impl std::convert::From<PyPolarsErr> for PyErr {
PolarsError::Io(err) => PyIOError::new_err(err.to_string()),
PolarsError::OutOfBounds(err) => PyIndexError::new_err(err.to_string()),
PolarsError::InvalidOperation(err) => PyValueError::new_err(err.to_string()),
PolarsError::ArrowError(err) => ArrowErrorException::new_err(format!("{:?}", err)),
PolarsError::Duplicate(err) => DuplicateError::new_err(err.to_string()),
PolarsError::ColumnNotFound(err) => ColumnNotFound::new_err(err.to_string()),
PolarsError::SchemaFieldNotFound(err) => {
Expand All @@ -44,7 +40,6 @@ impl std::convert::From<PyPolarsErr> for PyErr {
StringCacheMismatchError::new_err(err.to_string())
}
},
Arrow(err) => ArrowErrorException::new_err(format!("{:?}", err)),
_ => default(),
}
}
Expand All @@ -56,7 +51,6 @@ impl Debug for PyPolarsErr {
match self {
Polars(err) => write!(f, "{:?}", err),
Other(err) => write!(f, "BindingsError: {:?}", err),
Arrow(err) => write!(f, "{:?}", err),
}
}
}
Expand All @@ -66,7 +60,6 @@ create_exception!(exceptions, SchemaFieldNotFound, PyException);
create_exception!(exceptions, StructFieldNotFound, PyException);
create_exception!(exceptions, ComputeError, PyException);
create_exception!(exceptions, NoDataError, PyException);
create_exception!(exceptions, ArrowErrorException, PyException);
create_exception!(exceptions, ShapeError, PyException);
create_exception!(exceptions, SchemaError, PyException);
create_exception!(exceptions, DuplicateError, PyException);
Expand Down
2 changes: 2 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly-2023-10-12"

0 comments on commit 8ab4f14

Please sign in to comment.