Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: udf node creation in dfg #681

Merged
merged 8 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions crates/sparrow-compiler/src/ast_to_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use itertools::{izip, Itertools};
use record_ops_to_dfg::*;
use smallvec::{smallvec, SmallVec};
use sparrow_arrow::scalar_value::ScalarValue;
use sparrow_instructions::CastEvaluator;
use sparrow_instructions::{CastEvaluator, Udf};
use sparrow_instructions::{GroupId, InstKind, InstOp};
use sparrow_syntax::{
Collection, ExprOp, FenlType, FormatDataType, LiteralValue, Located, Location, Resolved,
Expand All @@ -29,7 +29,7 @@ use crate::dfg::{Dfg, Expression, Operation};
use crate::diagnostics::DiagnosticCode;
use crate::time_domain::TimeDomain;
use crate::types::inference::instantiate;
use crate::{DataContext, DiagnosticBuilder, DiagnosticCollector};
use crate::{DataContext, DiagnosticBuilder, DiagnosticCollector, TimeDomainCheck};

/// Convert the `expr` to corresponding DFG nodes.
pub(super) fn ast_to_dfg(
Expand Down Expand Up @@ -95,6 +95,69 @@ pub(super) fn ast_to_dfg(
)
}

pub fn add_udf_to_dfg(
bjchambers marked this conversation as resolved.
Show resolved Hide resolved
location: &Located<String>,
udf: Arc<dyn Udf>,
dfg: &mut Dfg,
arguments: Resolved<Located<AstDfgRef>>,
data_context: &mut DataContext,
diagnostics: &mut DiagnosticCollector<'_>,
) -> anyhow::Result<AstDfgRef> {
let argument_types = arguments.transform(|i| i.with_value(i.value_type().clone()));
let signature = udf.signature();

let (instantiated_types, instantiated_result_type) =
match instantiate(location, &argument_types, signature) {
Ok(result) => result,
Err(diagnostic) => {
diagnostic.emit(diagnostics);
return Ok(dfg.error_node());
}
};

if argument_types.iter().any(|arg| arg.is_error()) {
return Ok(dfg.error_node());
}
let grouping = verify_same_partitioning(
data_context,
diagnostics,
&location.with_value(location.inner().as_str()),
&arguments,
)?;

let args: Vec<_> = izip!(arguments, instantiated_types)
.map(|(arg, expected_type)| -> anyhow::Result<_> {
let ast_dfg = Arc::new(AstDfg::new(
cast_if_needed(dfg, arg.value(), arg.value_type(), &expected_type)?,
arg.is_new(),
expected_type,
arg.grouping(),
arg.time_domain().clone(),
arg.location().clone(),
None,
));
Ok(arg.with_value(ast_dfg))
})
.try_collect()?;

let is_new = dfg.add_udf(udf.clone(), args.iter().map(|i| i.value()).collect())?;
let value = is_any_new(dfg, &args)?;

let time_domain_check = TimeDomainCheck::Compatible;
let time_domain =
time_domain_check.check_args(location.location(), diagnostics, &args, data_context)?;

Ok(Arc::new(AstDfg::new(
value,
is_new,
instantiated_result_type,
grouping,
time_domain,
location.location().clone(),
None,
)))
}

pub fn add_to_dfg(
data_context: &mut DataContext,
dfg: &mut Dfg,
Expand Down
15 changes: 13 additions & 2 deletions crates/sparrow-compiler/src/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use hashbrown::HashMap;
use itertools::{izip, Itertools};
pub(crate) use language::ChildrenVec;
use sparrow_arrow::scalar_value::ScalarValue;
use sparrow_instructions::{InstKind, InstOp};
use sparrow_instructions::{InstKind, InstOp, Udf};
use sparrow_syntax::{FenlType, Location};
pub(crate) use step_kind::*;
type DfgGraph = egg::EGraph<language::DfgLang, analysis::DfgAnalysis>;
Expand Down Expand Up @@ -125,6 +125,15 @@ impl Dfg {
self.add_expression(Expression::Inst(InstKind::Simple(instruction)), children)
}

/// Add a udf node to the DFG.
pub(super) fn add_udf(
jordanrfrazier marked this conversation as resolved.
Show resolved Hide resolved
&mut self,
udf: Arc<dyn Udf>,
children: ChildrenVec,
) -> anyhow::Result<Id> {
self.add_expression(Expression::Inst(InstKind::Udf(udf)), children)
}

/// Add an expression to the DFG.
pub(super) fn add_expression(
&mut self,
Expand Down Expand Up @@ -255,7 +264,6 @@ impl Dfg {

// 2. The number of args should be correct.
match expr {
Expression::Inst(InstKind::Udf(_)) => unimplemented!("udf unsupported"),
Expression::Literal(_) | Expression::LateBound(_) => {
anyhow::ensure!(
children.len() == 1,
Expand All @@ -267,6 +275,9 @@ impl Dfg {
Expression::Inst(InstKind::Simple(op)) => op
.signature()
.assert_valid_argument_count(children.len() - 1),
Expression::Inst(InstKind::Udf(udf)) => udf
.signature()
.assert_valid_argument_count(children.len() - 1),
Expression::Inst(InstKind::FieldRef) => {
anyhow::ensure!(
children.len() == 3,
Expand Down
1 change: 1 addition & 0 deletions crates/sparrow-compiler/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub use function::*;
use implementation::*;
pub(crate) use pushdown::*;
pub use registry::*;
pub use time_domain_check::*;

/// Register all the functions available in the registry.
fn register_functions(registry: &mut Registry) {
Expand Down
4 changes: 2 additions & 2 deletions crates/sparrow-compiler/src/functions/time_domain_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{AstDfgRef, DataContext, DiagnosticCollector};
/// are subject to change. It may be better to use a closure to allow defining
/// the special behaviors as part of each function.
#[derive(Default)]
pub(super) enum TimeDomainCheck {
pub enum TimeDomainCheck {
/// The function requires the arguments to be compatible, and returns
/// the resulting time domain.
///
Expand Down Expand Up @@ -47,7 +47,7 @@ pub(super) enum TimeDomainCheck {
}

impl TimeDomainCheck {
pub(super) fn check_args(
pub fn check_args(
&self,
location: &Location,
diagnostics: &mut DiagnosticCollector<'_>,
Expand Down
11 changes: 10 additions & 1 deletion crates/sparrow-compiler/src/types/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ pub(crate) fn typecheck_inst(
argument_literals: &[Option<ScalarValue>],
) -> anyhow::Result<FenlType> {
match inst {
InstKind::Udf(_) => unimplemented!("udf type checking unsupported"),
InstKind::Simple(instruction) => {
let signature = instruction.signature();
let argument_types = Resolved::new(
Expand All @@ -33,6 +32,16 @@ pub(crate) fn typecheck_inst(

validate_instantiation(&argument_types, signature)
}
InstKind::Udf(udf) => {
let signature = udf.signature();
let argument_types = Resolved::new(
Cow::Owned(signature.parameters().names().to_owned()),
jordanrfrazier marked this conversation as resolved.
Show resolved Hide resolved
argument_types,
signature.parameters().has_vararg,
);

validate_instantiation(&argument_types, signature)
}
InstKind::FieldRef => {
anyhow::ensure!(
argument_types.len() == 2,
Expand Down
8 changes: 6 additions & 2 deletions crates/sparrow-instructions/src/inst.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt;
use std::hash::Hash;
use std::sync::Arc;

use arrow::datatypes::DataType;
Expand Down Expand Up @@ -261,12 +262,15 @@ impl PartialEq for InstKind {
}
}

use std::hash::Hash;
impl Hash for InstKind {
jordanrfrazier marked this conversation as resolved.
Show resolved Hide resolved
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
InstKind::Udf(udf) => udf.hash(state),
_ => core::mem::discriminant(self).hash(state),
InstKind::Simple(op) => op.hash(state),
InstKind::Cast(dt) => dt.hash(state),
InstKind::Record => {}
InstKind::FieldRef => {}
}
}
}
Expand Down
53 changes: 52 additions & 1 deletion crates/sparrow-session/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use sparrow_api::kaskada::v1alpha::{
ComputeTable, FeatureSet, PerEntityBehavior, TableConfig, TableMetadata,
};
use sparrow_compiler::{AstDfgRef, CompilerOptions, DataContext, Dfg, DiagnosticCollector};
use sparrow_instructions::GroupId;
use sparrow_instructions::{GroupId, Udf};
use sparrow_runtime::execute::output::Destination;
use sparrow_runtime::key_hash_inverse::ThreadSafeKeyHashInverse;
use sparrow_syntax::{ExprOp, FenlType, LiteralValue, Located, Location, Resolved};
Expand Down Expand Up @@ -277,6 +277,57 @@ impl Session {
}
}

/// The [Expr] will call this to add a user-defined-function to the DFG directly.
///
/// This bypasses much of the plumbing of the [ExprOp] required due to our construction
/// of the AST.
#[allow(unused)]
fn add_udf_to_dfg(
&mut self,
udf: Arc<dyn Udf>,
args: Vec<Expr>,
) -> error_stack::Result<AstDfgRef, Error> {
let signature = udf.signature();
let arg_names = signature.arg_names().to_owned();
jordanrfrazier marked this conversation as resolved.
Show resolved Hide resolved
signature.assert_valid_argument_count(args.len());

let has_vararg =
signature.parameters().has_vararg && args.len() > signature.arg_names().len();
let args = Resolved::new(
arg_names.into(),
args.into_iter()
.map(|arg| Located::builder(arg.0))
.collect(),
has_vararg,
);
let feature_set = FeatureSet::default();
let mut diagnostics = DiagnosticCollector::new(&feature_set);

let location = Located::builder("udf".to_owned());
let result = sparrow_compiler::add_udf_to_dfg(
&location,
udf.clone(),
&mut self.dfg,
args,
&mut self.data_context,
&mut diagnostics,
)
.into_report()
.change_context(Error::Invalid)?;

if diagnostics.num_errors() > 0 {
let errors = diagnostics
.finish()
.into_iter()
.filter(|diagnostic| diagnostic.is_error())
.map(|diagnostic| diagnostic.formatted)
.collect();
Err(Error::Errors(errors))?
} else {
Ok(result)
}
}

pub fn execute(
&self,
expr: &Expr,
Expand Down
2 changes: 1 addition & 1 deletion python/mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[mypy]

[mypy-desert,marshmallow,nox.*,pytest,pytest_mock,_pytest.*]
[mypy-plotly.*,mypy-desert,marshmallow,nox.*,pytest,pytest_mock,_pytest.*]
ignore_missing_imports = True

# pyarrow doesn't currently expose mypy stubs:
Expand Down
1 change: 1 addition & 0 deletions python/pysrc/kaskada/_timestream.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def make_arg(arg: Union[Timestream, Literal]) -> _ffi.Expr:
ffi_args = [make_arg(arg) for arg in args]
try:
return Timestream(
# TODO: FRAZ - so I need a `call` that can take the udf
_ffi.Expr.call(session=session, operation=func, args=ffi_args)
)
except TypeError as e:
Expand Down
1 change: 1 addition & 0 deletions python/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ impl Expr {

let mut rust_session = session.rust_session()?;
let args: Vec<_> = args.into_iter().map(|e| e.rust_expr).collect();
// TODO: - Support adding a UDF here.
let rust_expr = match rust_session.add_expr(&operation, args) {
Ok(node) => node,
Err(e) => {
Expand Down
Loading