diff --git a/crates/sparrow-compiler/src/ast_to_dfg.rs b/crates/sparrow-compiler/src/ast_to_dfg.rs index 139528820..8d0e7f80c 100644 --- a/crates/sparrow-compiler/src/ast_to_dfg.rs +++ b/crates/sparrow-compiler/src/ast_to_dfg.rs @@ -17,7 +17,7 @@ use itertools::{izip, Itertools}; use record_ops_to_dfg::*; use smallvec::{smallvec, SmallVec}; use sparrow_arrow::scalar_value::ScalarValue; -use sparrow_instructions::CastEvaluator; +use sparrow_instructions::{CastEvaluator, Udf}; use sparrow_instructions::{GroupId, InstKind, InstOp}; use sparrow_syntax::{ Collection, ExprOp, FenlType, FormatDataType, LiteralValue, Located, Location, Resolved, @@ -29,7 +29,7 @@ use crate::dfg::{Dfg, Expression, Operation}; use crate::diagnostics::DiagnosticCode; use crate::time_domain::TimeDomain; use crate::types::inference::instantiate; -use crate::{DataContext, DiagnosticBuilder, DiagnosticCollector}; +use crate::{DataContext, DiagnosticBuilder, DiagnosticCollector, TimeDomainCheck}; /// Convert the `expr` to corresponding DFG nodes. 
pub(super) fn ast_to_dfg( @@ -95,6 +95,69 @@ pub(super) fn ast_to_dfg( ) } +pub fn add_udf_to_dfg( + location: &Located, + udf: Arc, + dfg: &mut Dfg, + arguments: Resolved>, + data_context: &mut DataContext, + diagnostics: &mut DiagnosticCollector<'_>, +) -> anyhow::Result { + let argument_types = arguments.transform(|i| i.with_value(i.value_type().clone())); + let signature = udf.signature(); + + let (instantiated_types, instantiated_result_type) = + match instantiate(location, &argument_types, signature) { + Ok(result) => result, + Err(diagnostic) => { + diagnostic.emit(diagnostics); + return Ok(dfg.error_node()); + } + }; + + if argument_types.iter().any(|arg| arg.is_error()) { + return Ok(dfg.error_node()); + } + let grouping = verify_same_partitioning( + data_context, + diagnostics, + &location.with_value(location.inner().as_str()), + &arguments, + )?; + + let args: Vec<_> = izip!(arguments, instantiated_types) + .map(|(arg, expected_type)| -> anyhow::Result<_> { + let ast_dfg = Arc::new(AstDfg::new( + cast_if_needed(dfg, arg.value(), arg.value_type(), &expected_type)?, + arg.is_new(), + expected_type, + arg.grouping(), + arg.time_domain().clone(), + arg.location().clone(), + None, + )); + Ok(arg.with_value(ast_dfg)) + }) + .try_collect()?; + + let is_new = dfg.add_udf(udf.clone(), args.iter().map(|i| i.value()).collect())?; + let value = is_any_new(dfg, &args)?; + + let time_domain_check = TimeDomainCheck::Compatible; + let time_domain = + time_domain_check.check_args(location.location(), diagnostics, &args, data_context)?; + + Ok(Arc::new(AstDfg::new( + value, + is_new, + instantiated_result_type, + grouping, + time_domain, + location.location().clone(), + None, + ))) +} + pub fn add_to_dfg( data_context: &mut DataContext, dfg: &mut Dfg, diff --git a/crates/sparrow-compiler/src/dfg.rs b/crates/sparrow-compiler/src/dfg.rs index ad8c66836..45301d9c7 100644 --- a/crates/sparrow-compiler/src/dfg.rs +++ b/crates/sparrow-compiler/src/dfg.rs @@ -37,7 +37,7 @@ use 
hashbrown::HashMap; use itertools::{izip, Itertools}; pub(crate) use language::ChildrenVec; use sparrow_arrow::scalar_value::ScalarValue; -use sparrow_instructions::{InstKind, InstOp}; +use sparrow_instructions::{InstKind, InstOp, Udf}; use sparrow_syntax::{FenlType, Location}; pub(crate) use step_kind::*; type DfgGraph = egg::EGraph; @@ -125,6 +125,15 @@ impl Dfg { self.add_expression(Expression::Inst(InstKind::Simple(instruction)), children) } + /// Add an `InstKind::Udf` instruction node for `udf` to the DFG via `add_expression`. + pub(super) fn add_udf( + &mut self, + udf: Arc, + children: ChildrenVec, + ) -> anyhow::Result { + self.add_expression(Expression::Inst(InstKind::Udf(udf)), children) + } + /// Add an expression to the DFG. pub(super) fn add_expression( &mut self, @@ -255,7 +264,6 @@ impl Dfg { // 2. The number of args should be correct. match expr { - Expression::Inst(InstKind::Udf(_)) => unimplemented!("udf unsupported"), Expression::Literal(_) | Expression::LateBound(_) => { anyhow::ensure!( children.len() == 1, @@ -267,6 +275,9 @@ impl Dfg { Expression::Inst(InstKind::Simple(op)) => op .signature() .assert_valid_argument_count(children.len() - 1), + Expression::Inst(InstKind::Udf(udf)) => udf + .signature() + .assert_valid_argument_count(children.len() - 1), Expression::Inst(InstKind::FieldRef) => { anyhow::ensure!( children.len() == 3, diff --git a/crates/sparrow-compiler/src/functions.rs b/crates/sparrow-compiler/src/functions.rs index eff086b97..5aa988431 100644 --- a/crates/sparrow-compiler/src/functions.rs +++ b/crates/sparrow-compiler/src/functions.rs @@ -20,6 +20,7 @@ pub use function::*; use implementation::*; pub(crate) use pushdown::*; pub use registry::*; +pub use time_domain_check::*; /// Register all the functions available in the registry.
fn register_functions(registry: &mut Registry) { diff --git a/crates/sparrow-compiler/src/functions/time_domain_check.rs b/crates/sparrow-compiler/src/functions/time_domain_check.rs index 4967ef701..7748fde50 100644 --- a/crates/sparrow-compiler/src/functions/time_domain_check.rs +++ b/crates/sparrow-compiler/src/functions/time_domain_check.rs @@ -12,7 +12,7 @@ use crate::{AstDfgRef, DataContext, DiagnosticCollector}; /// are subject to change. It may be better to use a closure to allow defining /// the special behaviors as part of each function. #[derive(Default)] -pub(super) enum TimeDomainCheck { +pub enum TimeDomainCheck { /// The function requires the arguments to be compatible, and returns /// the resulting time domain. /// @@ -47,7 +47,7 @@ pub(super) enum TimeDomainCheck { } impl TimeDomainCheck { - pub(super) fn check_args( + pub fn check_args( &self, location: &Location, diagnostics: &mut DiagnosticCollector<'_>, diff --git a/crates/sparrow-compiler/src/types/instruction.rs b/crates/sparrow-compiler/src/types/instruction.rs index 8780af7b7..08de8c816 100644 --- a/crates/sparrow-compiler/src/types/instruction.rs +++ b/crates/sparrow-compiler/src/types/instruction.rs @@ -22,7 +22,6 @@ pub(crate) fn typecheck_inst( argument_literals: &[Option], ) -> anyhow::Result { match inst { - InstKind::Udf(_) => unimplemented!("udf type checking unsupported"), InstKind::Simple(instruction) => { let signature = instruction.signature(); let argument_types = Resolved::new( @@ -33,6 +32,16 @@ pub(crate) fn typecheck_inst( validate_instantiation(&argument_types, signature) } + InstKind::Udf(udf) => { + let signature = udf.signature(); + let argument_types = Resolved::new( + Cow::Owned(signature.parameters().names().to_owned()), + argument_types, + signature.parameters().has_vararg, + ); + + validate_instantiation(&argument_types, signature) + } InstKind::FieldRef => { anyhow::ensure!( argument_types.len() == 2, diff --git a/crates/sparrow-instructions/src/inst.rs 
b/crates/sparrow-instructions/src/inst.rs index c4bf4aec9..3fa209b5a 100644 --- a/crates/sparrow-instructions/src/inst.rs +++ b/crates/sparrow-instructions/src/inst.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::hash::Hash; use std::sync::Arc; use arrow::datatypes::DataType; @@ -261,12 +262,15 @@ impl PartialEq for InstKind { } } -use std::hash::Hash; impl Hash for InstKind { fn hash(&self, state: &mut H) { + core::mem::discriminant(self).hash(state); match self { InstKind::Udf(udf) => udf.hash(state), - _ => core::mem::discriminant(self).hash(state), + InstKind::Simple(op) => op.hash(state), + InstKind::Cast(dt) => dt.hash(state), + InstKind::Record => {} + InstKind::FieldRef => {} } } } diff --git a/crates/sparrow-session/src/session.rs b/crates/sparrow-session/src/session.rs index 1da88c21e..77c5f3640 100644 --- a/crates/sparrow-session/src/session.rs +++ b/crates/sparrow-session/src/session.rs @@ -11,7 +11,7 @@ use sparrow_api::kaskada::v1alpha::{ ComputeTable, FeatureSet, PerEntityBehavior, TableConfig, TableMetadata, }; use sparrow_compiler::{AstDfgRef, CompilerOptions, DataContext, Dfg, DiagnosticCollector}; -use sparrow_instructions::GroupId; +use sparrow_instructions::{GroupId, Udf}; use sparrow_runtime::execute::output::Destination; use sparrow_runtime::key_hash_inverse::ThreadSafeKeyHashInverse; use sparrow_syntax::{ExprOp, FenlType, LiteralValue, Located, Location, Resolved}; @@ -277,6 +277,57 @@ impl Session { } } + /// The [Expr] will call this to add a user-defined-function to the DFG directly. + /// + /// This bypasses much of the plumbing of the [ExprOp] required due to our construction + /// of the AST. 
+ #[allow(unused)] + fn add_udf_to_dfg( + &mut self, + udf: Arc, + args: Vec, + ) -> error_stack::Result { + let signature = udf.signature(); + let arg_names = signature.arg_names().to_owned(); + signature.assert_valid_argument_count(args.len()); + + let has_vararg = + signature.parameters().has_vararg && args.len() > signature.arg_names().len(); + let args = Resolved::new( + arg_names.into(), + args.into_iter() + .map(|arg| Located::builder(arg.0)) + .collect(), + has_vararg, + ); + let feature_set = FeatureSet::default(); + let mut diagnostics = DiagnosticCollector::new(&feature_set); + + let location = Located::builder("udf".to_owned()); + let result = sparrow_compiler::add_udf_to_dfg( + &location, + udf.clone(), + &mut self.dfg, + args, + &mut self.data_context, + &mut diagnostics, + ) + .into_report() + .change_context(Error::Invalid)?; + + if diagnostics.num_errors() > 0 { + let errors = diagnostics + .finish() + .into_iter() + .filter(|diagnostic| diagnostic.is_error()) + .map(|diagnostic| diagnostic.formatted) + .collect(); + Err(Error::Errors(errors))? 
+ } else { + Ok(result) + } + } + pub fn execute( &self, expr: &Expr, diff --git a/python/mypy.ini b/python/mypy.ini index a37604c9f..07bbced46 100644 --- a/python/mypy.ini +++ b/python/mypy.ini @@ -1,6 +1,6 @@ [mypy] -[mypy-desert,marshmallow,nox.*,pytest,pytest_mock,_pytest.*] +[mypy-plotly.*,desert,marshmallow,nox.*,pytest,pytest_mock,_pytest.*] ignore_missing_imports = True # pyarrow doesn't currently expose mypy stubs: diff --git a/python/pysrc/kaskada/_timestream.py b/python/pysrc/kaskada/_timestream.py index 2750ae66f..f5d761cae 100644 --- a/python/pysrc/kaskada/_timestream.py +++ b/python/pysrc/kaskada/_timestream.py @@ -112,6 +112,7 @@ def make_arg(arg: Union[Timestream, Literal]) -> _ffi.Expr: ffi_args = [make_arg(arg) for arg in args] try: return Timestream( + # TODO(FRAZ): `_ffi.Expr.call` needs a variant that accepts a UDF. _ffi.Expr.call(session=session, operation=func, args=ffi_args) ) except TypeError as e: diff --git a/python/src/expr.rs b/python/src/expr.rs index f1f3ef16f..0acc47464 100644 --- a/python/src/expr.rs +++ b/python/src/expr.rs @@ -31,6 +31,7 @@ impl Expr { let mut rust_session = session.rust_session()?; let args: Vec<_> = args.into_iter().map(|e| e.rust_expr).collect(); + // TODO: Support adding a UDF here. let rust_expr = match rust_session.add_expr(&operation, args) { Ok(node) => node, Err(e) => {