From 64c9709ba42f27eeafbeeb88cdff41cac9207f4c Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 17 Oct 2023 16:54:26 -0400 Subject: [PATCH] Add random missing bindings --- src/expr/aggregate.rs | 49 +++++++++++++++++++++++++++++++++++++++++++ src/expr/join.rs | 14 +++++++++++++ src/expr/sort.rs | 4 ++++ 3 files changed, 67 insertions(+) diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs index c3de9673a..5ebf8c6cf 100644 --- a/src/expr/aggregate.rs +++ b/src/expr/aggregate.rs @@ -16,12 +16,15 @@ // under the License. use datafusion_common::DataFusionError; +use datafusion_expr::expr::{AggregateFunction, AggregateUDF, Alias}; use datafusion_expr::logical_plan::Aggregate; +use datafusion_expr::Expr; use pyo3::prelude::*; use std::fmt::{self, Display, Formatter}; use super::logical_node::LogicalNode; use crate::common::df_schema::PyDFSchema; +use crate::errors::py_type_err; use crate::expr::PyExpr; use crate::sql::logical::PyLogicalPlan; @@ -84,6 +87,24 @@ impl PyAggregate { .collect()) } + /// Returns the inner Aggregate Expr(s) + pub fn agg_expressions(&self) -> PyResult> { + Ok(self + .aggregate + .aggr_expr + .iter() + .map(|e| PyExpr::from(e.clone())) + .collect()) + } + + pub fn agg_func_name(&self, expr: PyExpr) -> PyResult { + Self::_agg_func_name(&expr.expr) + } + + pub fn aggregation_arguments(&self, expr: PyExpr) -> PyResult> { + self._aggregation_arguments(&expr.expr) + } + // Retrieves the input `LogicalPlan` to this `Aggregate` node fn input(&self) -> PyResult> { Ok(Self::inputs(self)) @@ -99,6 +120,34 @@ impl PyAggregate { } } +impl PyAggregate { + #[allow(clippy::only_used_in_recursion)] + fn _aggregation_arguments(&self, expr: &Expr) -> PyResult> { + match expr { + // TODO: This Alias logic seems to be returning some strange results that we should investigate + Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()), + Expr::AggregateFunction(AggregateFunction { fun: _, args, .. }) + | Expr::AggregateUDF(AggregateUDF { fun: _, args, .. }) => { + Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect()) + } + _ => Err(py_type_err( + "Encountered a non Aggregate type in aggregation_arguments", + )), + } + } + + fn _agg_func_name(expr: &Expr) -> PyResult { + match expr { + Expr::Alias(Alias { expr, .. }) => Self::_agg_func_name(expr.as_ref()), + Expr::AggregateFunction(AggregateFunction { fun, .. }) => Ok(fun.to_string()), + Expr::AggregateUDF(AggregateUDF { fun, .. }) => Ok(fun.name.clone()), + _ => Err(py_type_err( + "Encountered a non Aggregate type in agg_func_name", + )), + } + } +} + impl LogicalNode for PyAggregate { fn inputs(&self) -> Vec { vec![PyLogicalPlan::from((*self.aggregate.input).clone())] diff --git a/src/expr/join.rs b/src/expr/join.rs index 801662962..a53ddd3ba 100644 --- a/src/expr/join.rs +++ b/src/expr/join.rs @@ -46,6 +46,10 @@ impl PyJoinType { pub fn is_outer(&self) -> bool { self.join_type.is_outer() } + + fn __repr__(&self) -> PyResult { + Ok(format!("{}", self.join_type)) + } } impl Display for PyJoinType { @@ -72,6 +76,16 @@ impl From for JoinConstraint { } } +#[pymethods] +impl PyJoinConstraint { + fn __repr__(&self) -> PyResult { + match self.join_constraint { + JoinConstraint::On => Ok("On".to_string()), + JoinConstraint::Using => Ok("Using".to_string()), + } + } +} + #[pyclass(name = "Join", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyJoin { diff --git a/src/expr/sort.rs b/src/expr/sort.rs index 8843c638d..f9f9e5899 100644 --- a/src/expr/sort.rs +++ b/src/expr/sort.rs @@ -72,6 +72,10 @@ impl PySort { .collect()) } + fn get_fetch_val(&self) -> PyResult> { + Ok(self.sort.fetch) + } + /// Retrieves the input `LogicalPlan` to this `Sort` node fn input(&self) -> PyResult> { Ok(Self::inputs(self))