Skip to content

Commit

Permalink
Add random missing bindings (#522)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdye64 authored Oct 18, 2023
1 parent 399fa75 commit c2768d8
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/expr/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
// under the License.

use datafusion_common::DataFusionError;
use datafusion_expr::expr::{AggregateFunction, AggregateUDF, Alias};
use datafusion_expr::logical_plan::Aggregate;
use datafusion_expr::Expr;
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};

use super::logical_node::LogicalNode;
use crate::common::df_schema::PyDFSchema;
use crate::errors::py_type_err;
use crate::expr::PyExpr;
use crate::sql::logical::PyLogicalPlan;

Expand Down Expand Up @@ -84,6 +87,24 @@ impl PyAggregate {
.collect())
}

/// Returns the inner Aggregate Expr(s)
pub fn agg_expressions(&self) -> PyResult<Vec<PyExpr>> {
Ok(self
.aggregate
.aggr_expr
.iter()
.map(|e| PyExpr::from(e.clone()))
.collect())
}

pub fn agg_func_name(&self, expr: PyExpr) -> PyResult<String> {
Self::_agg_func_name(&expr.expr)
}

pub fn aggregation_arguments(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
self._aggregation_arguments(&expr.expr)
}

// Retrieves the input `LogicalPlan` to this `Aggregate` node
fn input(&self) -> PyResult<Vec<PyLogicalPlan>> {
Ok(Self::inputs(self))
Expand All @@ -99,6 +120,34 @@ impl PyAggregate {
}
}

impl PyAggregate {
#[allow(clippy::only_used_in_recursion)]
fn _aggregation_arguments(&self, expr: &Expr) -> PyResult<Vec<PyExpr>> {
match expr {
// TODO: This Alias logic seems to be returning some strange results that we should investigate
Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()),
Expr::AggregateFunction(AggregateFunction { fun: _, args, .. })
| Expr::AggregateUDF(AggregateUDF { fun: _, args, .. }) => {
Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect())
}
_ => Err(py_type_err(
"Encountered a non Aggregate type in aggregation_arguments",
)),
}
}

fn _agg_func_name(expr: &Expr) -> PyResult<String> {
match expr {
Expr::Alias(Alias { expr, .. }) => Self::_agg_func_name(expr.as_ref()),
Expr::AggregateFunction(AggregateFunction { fun, .. }) => Ok(fun.to_string()),
Expr::AggregateUDF(AggregateUDF { fun, .. }) => Ok(fun.name.clone()),
_ => Err(py_type_err(
"Encountered a non Aggregate type in agg_func_name",
)),
}
}
}

impl LogicalNode for PyAggregate {
fn inputs(&self) -> Vec<PyLogicalPlan> {
vec![PyLogicalPlan::from((*self.aggregate.input).clone())]
Expand Down
14 changes: 14 additions & 0 deletions src/expr/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ impl PyJoinType {
pub fn is_outer(&self) -> bool {
self.join_type.is_outer()
}

fn __repr__(&self) -> PyResult<String> {
Ok(format!("{}", self.join_type))
}
}

impl Display for PyJoinType {
Expand All @@ -72,6 +76,16 @@ impl From<PyJoinConstraint> for JoinConstraint {
}
}

#[pymethods]
impl PyJoinConstraint {
fn __repr__(&self) -> PyResult<String> {
match self.join_constraint {
JoinConstraint::On => Ok("On".to_string()),
JoinConstraint::Using => Ok("Using".to_string()),
}
}
}

#[pyclass(name = "Join", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyJoin {
Expand Down
4 changes: 4 additions & 0 deletions src/expr/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ impl PySort {
.collect())
}

fn get_fetch_val(&self) -> PyResult<Option<usize>> {
Ok(self.sort.fetch)
}

/// Retrieves the input `LogicalPlan` to this `Sort` node
fn input(&self) -> PyResult<Vec<PyLogicalPlan>> {
Ok(Self::inputs(self))
Expand Down

0 comments on commit c2768d8

Please sign in to comment.