From ca83c8ee85170a36bab895e70686c181d250cbd8 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 20 Nov 2023 21:04:04 -0500 Subject: [PATCH 1/3] Adjust visibility of certain functions that cause issues for third party applications when they are private --- src/catalog.rs | 10 +++--- src/common/df_field.rs | 2 +- src/context.rs | 74 +++++++++++++++++++++--------------------- src/expr.rs | 1 + src/lib.rs | 2 +- src/substrait.rs | 12 +++---- src/utils.rs | 12 ++++++- 7 files changed, 62 insertions(+), 51 deletions(-) diff --git a/src/catalog.rs b/src/catalog.rs index 94faea067..ba7e22255 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -30,18 +30,18 @@ use datafusion::{ }; #[pyclass(name = "Catalog", module = "datafusion", subclass)] -pub(crate) struct PyCatalog { - catalog: Arc, +pub struct PyCatalog { + pub catalog: Arc, } #[pyclass(name = "Database", module = "datafusion", subclass)] -pub(crate) struct PyDatabase { - database: Arc, +pub struct PyDatabase { + pub database: Arc, } #[pyclass(name = "Table", module = "datafusion", subclass)] pub struct PyTable { - table: Arc, + pub table: Arc, } impl PyCatalog { diff --git a/src/common/df_field.rs b/src/common/df_field.rs index 703af0aa2..68c05361f 100644 --- a/src/common/df_field.rs +++ b/src/common/df_field.rs @@ -27,7 +27,7 @@ use super::data_type::PyDataType; #[pyclass(name = "DFField", module = "datafusion.common", subclass)] #[derive(Debug, Clone)] pub struct PyDFField { - field: DFField, + pub field: DFField, } impl From for DFField { diff --git a/src/context.rs b/src/context.rs index a0361acfc..764a96125 100644 --- a/src/context.rs +++ b/src/context.rs @@ -60,8 +60,8 @@ use tokio::task::JoinHandle; /// Configuration options for a SessionContext #[pyclass(name = "SessionConfig", module = "datafusion", subclass)] #[derive(Clone, Default)] -pub(crate) struct PySessionConfig { - pub(crate) config: SessionConfig, +pub struct PySessionConfig { + pub config: SessionConfig, } impl From for PySessionConfig { @@ -153,8 +153,8 @@ impl PySessionConfig { /// Runtime options for a SessionContext #[pyclass(name = "RuntimeConfig", module = "datafusion", subclass)] #[derive(Clone)] -pub(crate) struct PyRuntimeConfig { - pub(crate) config: RuntimeConfig, +pub struct PyRuntimeConfig { + pub config: RuntimeConfig, } #[pymethods] @@ -215,15 +215,15 @@ impl PyRuntimeConfig { /// multi-threaded execution engine to perform the execution. #[pyclass(name = "SessionContext", module = "datafusion", subclass)] #[derive(Clone)] -pub(crate) struct PySessionContext { - pub(crate) ctx: SessionContext, +pub struct PySessionContext { + pub ctx: SessionContext, } #[pymethods] impl PySessionContext { #[pyo3(signature = (config=None, runtime=None))] #[new] - fn new(config: Option, runtime: Option) -> PyResult { + pub fn new(config: Option, runtime: Option) -> PyResult { let config = if let Some(c) = config { c.config } else { @@ -242,7 +242,7 @@ impl PySessionContext { } /// Register a an object store with the given name - fn register_object_store( + pub fn register_object_store( &mut self, scheme: &str, store: &PyAny, @@ -276,13 +276,13 @@ impl PySessionContext { } /// Returns a PyDataFrame whose plan corresponds to the SQL statement. - fn sql(&mut self, query: &str, py: Python) -> PyResult { + pub fn sql(&mut self, query: &str, py: Python) -> PyResult { let result = self.ctx.sql(query); let df = wait_for_future(py, result).map_err(DataFusionError::from)?; Ok(PyDataFrame::new(df)) } - fn create_dataframe( + pub fn create_dataframe( &mut self, partitions: PyArrowType>>, name: Option<&str>, @@ -314,13 +314,13 @@ impl PySessionContext { } /// Create a DataFrame from an existing logical plan - fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame { + pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame { PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone())) } /// Construct datafusion dataframe from Python list #[allow(clippy::wrong_self_convention)] - fn from_pylist( + pub fn from_pylist( &mut self, data: PyObject, name: Option<&str>, @@ -340,7 +340,7 @@ impl PySessionContext { /// Construct datafusion dataframe from Python dictionary #[allow(clippy::wrong_self_convention)] - fn from_pydict( + pub fn from_pydict( &mut self, data: PyObject, name: Option<&str>, @@ -360,7 +360,7 @@ impl PySessionContext { /// Construct datafusion dataframe from Arrow Table #[allow(clippy::wrong_self_convention)] - fn from_arrow_table( + pub fn from_arrow_table( &mut self, data: PyObject, name: Option<&str>, @@ -381,7 +381,7 @@ impl PySessionContext { /// Construct datafusion dataframe from pandas #[allow(clippy::wrong_self_convention)] - fn from_pandas( + pub fn from_pandas( &mut self, data: PyObject, name: Option<&str>, @@ -401,7 +401,7 @@ impl PySessionContext { /// Construct datafusion dataframe from polars #[allow(clippy::wrong_self_convention)] - fn from_polars( + pub fn from_polars( &mut self, data: PyObject, name: Option<&str>, @@ -417,21 +417,21 @@ impl PySessionContext { }) } - fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> { + pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> { self.ctx .register_table(name, table.table()) .map_err(DataFusionError::from)?; Ok(()) } - fn deregister_table(&mut self, name: &str) -> PyResult<()> { + pub fn deregister_table(&mut self, name: &str) -> PyResult<()> { self.ctx .deregister_table(name) .map_err(DataFusionError::from)?; Ok(()) } - fn register_record_batches( + pub fn register_record_batches( &mut self, name: &str, partitions: PyArrowType>>, @@ -451,7 +451,7 @@ impl PySessionContext { skip_metadata=true, schema=None, file_sort_order=None))] - fn register_parquet( + pub fn register_parquet( &mut self, name: &str, path: &str, @@ -489,7 +489,7 @@ impl PySessionContext { schema_infer_max_records=1000, file_extension=".csv", file_compression_type=None))] - fn register_csv( + pub fn register_csv( &mut self, name: &str, path: PathBuf, @@ -533,7 +533,7 @@ impl PySessionContext { file_extension=".json", table_partition_cols=vec![], file_compression_type=None))] - fn register_json( + pub fn register_json( &mut self, name: &str, path: PathBuf, @@ -568,7 +568,7 @@ impl PySessionContext { file_extension=".avro", table_partition_cols=vec![], infinite=false))] - fn register_avro( + pub fn register_avro( &mut self, name: &str, path: PathBuf, @@ -595,7 +595,7 @@ impl PySessionContext { } // Registers a PyArrow.Dataset - fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> { + pub fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> { let table: Arc = Arc::new(Dataset::new(dataset, py)?); self.ctx @@ -605,18 +605,18 @@ impl PySessionContext { Ok(()) } - fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> { + pub fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> { self.ctx.register_udf(udf.function); Ok(()) } - fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> { + pub fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> { self.ctx.register_udaf(udaf.function); Ok(()) } #[pyo3(signature = (name="datafusion"))] - fn catalog(&self, name: &str) -> PyResult { + pub fn catalog(&self, name: &str) -> PyResult { match self.ctx.catalog(name) { Some(catalog) => Ok(PyCatalog::new(catalog)), None => Err(PyKeyError::new_err(format!( @@ -626,31 +626,31 @@ impl PySessionContext { } } - fn tables(&self) -> HashSet { + pub fn tables(&self) -> HashSet { #[allow(deprecated)] self.ctx.tables().unwrap() } - fn table(&self, name: &str, py: Python) -> PyResult { + pub fn table(&self, name: &str, py: Python) -> PyResult { let x = wait_for_future(py, self.ctx.table(name)).map_err(DataFusionError::from)?; Ok(PyDataFrame::new(x)) } - fn table_exist(&self, name: &str) -> PyResult { + pub fn table_exist(&self, name: &str) -> PyResult { Ok(self.ctx.table_exist(name)?) } - fn empty_table(&self) -> PyResult { + pub fn empty_table(&self) -> PyResult { Ok(PyDataFrame::new(self.ctx.read_empty()?)) } - fn session_id(&self) -> String { + pub fn session_id(&self) -> String { self.ctx.session_id() } #[allow(clippy::too_many_arguments)] #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))] - fn read_json( + pub fn read_json( &mut self, path: PathBuf, schema: Option>, @@ -689,7 +689,7 @@ impl PySessionContext { file_extension=".csv", table_partition_cols=vec![], file_compression_type=None))] - fn read_csv( + pub fn read_csv( &self, path: PathBuf, schema: Option>, @@ -741,7 +741,7 @@ impl PySessionContext { skip_metadata=true, schema=None, file_sort_order=None))] - fn read_parquet( + pub fn read_parquet( &self, path: &str, table_partition_cols: Vec<(String, String)>, @@ -771,7 +771,7 @@ impl PySessionContext { #[allow(clippy::too_many_arguments)] #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))] - fn read_avro( + pub fn read_avro( &self, path: &str, schema: Option>, @@ -793,7 +793,7 @@ impl PySessionContext { Ok(PyDataFrame::new(df)) } - fn read_table(&self, table: &PyTable) -> PyResult { + pub fn read_table(&self, table: &PyTable) -> PyResult { let df = self .ctx .read_table(table.table()) diff --git a/src/expr.rs b/src/expr.rs index bbab8bf41..3875fb381 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -128,6 +128,7 @@ impl PyExpr { Expr::ScalarVariable(data_type, variables) => { Ok(PyScalarVariable::new(data_type, variables).into_py(py)) } + Expr::Like(value) => Ok(PyLike::from(value.clone()).into_py(py)), Expr::Literal(value) => Ok(PyLiteral::from(value.clone()).into_py(py)), Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_py(py)), Expr::Not(expr) => Ok(PyNot::new(*expr.clone()).into_py(py)), diff --git a/src/lib.rs b/src/lib.rs index 413b2a429..5e57db9cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,7 @@ pub mod common; #[allow(clippy::borrow_deref_ref)] mod config; #[allow(clippy::borrow_deref_ref)] -mod context; +pub mod context; #[allow(clippy::borrow_deref_ref)] mod dataframe; mod dataset; diff --git a/src/substrait.rs b/src/substrait.rs index 73606fdfa..ff83f6f79 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -29,8 +29,8 @@ use prost::Message; #[pyclass(name = "plan", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] -pub(crate) struct PyPlan { - pub(crate) plan: Plan, +pub struct PyPlan { + pub plan: Plan, } #[pymethods] @@ -61,7 +61,7 @@ impl From for PyPlan { /// to a valid `LogicalPlan` instance. #[pyclass(name = "serde", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] -pub(crate) struct PySubstraitSerializer; +pub struct PySubstraitSerializer; #[pymethods] impl PySubstraitSerializer { @@ -107,7 +107,7 @@ impl PySubstraitSerializer { #[pyclass(name = "producer", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] -pub(crate) struct PySubstraitProducer; +pub struct PySubstraitProducer; #[pymethods] impl PySubstraitProducer { @@ -123,7 +123,7 @@ impl PySubstraitProducer { #[pyclass(name = "consumer", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] -pub(crate) struct PySubstraitConsumer; +pub struct PySubstraitConsumer; #[pymethods] impl PySubstraitConsumer { @@ -140,7 +140,7 @@ impl PySubstraitConsumer { } } -pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { +pub fn init_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/utils.rs b/src/utils.rs index 427a8a064..e31765230 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -25,7 +25,17 @@ use tokio::runtime::Runtime; /// Utility to get the Tokio Runtime from Python pub(crate) fn get_tokio_runtime(py: Python) -> PyRef { let datafusion = py.import("datafusion._internal").unwrap(); - datafusion.getattr("runtime").unwrap().extract().unwrap() + let tmp = datafusion.getattr("runtime").unwrap(); + match tmp.extract::>() { + Ok(runtime) => { + runtime + }, + Err(_e) => { + let rt = TokioRuntime(tokio::runtime::Runtime::new().unwrap()); + let obj: &PyAny = Py::new(py, rt).unwrap().into_ref(py); + obj.extract().unwrap() + } + } } /// Utility to collect rust futures with GIL released From e1f7d623a783b6761ab6e8d96171e1e30c08dce9 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 21 Nov 2023 10:54:11 -0500 Subject: [PATCH 2/3] linter fixes --- src/context.rs | 5 ++++- src/utils.rs | 4 +--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/context.rs b/src/context.rs index 764a96125..e075cccda 100644 --- a/src/context.rs +++ b/src/context.rs @@ -223,7 +223,10 @@ pub struct PySessionContext { impl PySessionContext { #[pyo3(signature = (config=None, runtime=None))] #[new] - pub fn new(config: Option, runtime: Option) -> PyResult { + pub fn new( + config: Option, + runtime: Option + ) -> PyResult { let config = if let Some(c) = config { c.config } else { diff --git a/src/utils.rs b/src/utils.rs index e31765230..c5965bd2f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -27,9 +27,7 @@ pub(crate) fn get_tokio_runtime(py: Python) -> PyRef { let datafusion = py.import("datafusion._internal").unwrap(); let tmp = datafusion.getattr("runtime").unwrap(); match tmp.extract::>() { - Ok(runtime) => { - runtime - }, + Ok(runtime) => runtime, Err(_e) => { let rt = TokioRuntime(tokio::runtime::Runtime::new().unwrap()); let obj: &PyAny = Py::new(py, rt).unwrap().into_ref(py); From 6b62e5ab7e8b4aacfb41247889e1d4dde50c0e70 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 21 Nov 2023 11:01:31 -0500 Subject: [PATCH 3/3] linter fixes --- src/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context.rs b/src/context.rs index e075cccda..fb7ca8385 100644 --- a/src/context.rs +++ b/src/context.rs @@ -225,7 +225,7 @@ impl PySessionContext { #[new] pub fn new( config: Option, - runtime: Option + runtime: Option, ) -> PyResult { let config = if let Some(c) = config { c.config