diff --git a/air-script/tests/codegen/masm.rs b/air-script/tests/codegen/masm.rs index 983c2b60..2351a2c0 100644 --- a/air-script/tests/codegen/masm.rs +++ b/air-script/tests/codegen/masm.rs @@ -84,6 +84,16 @@ fn evaluators() { expected.assert_eq(&generated_masm); } +#[test] +fn functions() { + let generated_masm = Test::new("tests/functions/functions.air".to_string()) + .transpile(Target::Masm) + .unwrap(); + + let expected = expect_file!["../functions/functions.masm"]; + expected.assert_eq(&generated_masm); +} + #[test] fn variables() { let generated_masm = Test::new("tests/variables/variables.air".to_string()) diff --git a/air-script/tests/codegen/winterfell.rs b/air-script/tests/codegen/winterfell.rs index caad525f..36c39fe5 100644 --- a/air-script/tests/codegen/winterfell.rs +++ b/air-script/tests/codegen/winterfell.rs @@ -84,6 +84,16 @@ fn evaluators() { expected.assert_eq(&generated_air); } +#[test] +fn functions() { + let generated_air = Test::new("tests/functions/functions.air".to_string()) + .transpile(Target::Winterfell) + .unwrap(); + + let expected = expect_file!["../functions/functions.rs"]; + expected.assert_eq(&generated_air); +} + #[test] fn variables() { let generated_air = Test::new("tests/variables/variables.air".to_string()) diff --git a/air-script/tests/functions/functions.air b/air-script/tests/functions/functions.air new file mode 100644 index 00000000..7698e7bc --- /dev/null +++ b/air-script/tests/functions/functions.air @@ -0,0 +1,23 @@ +def FunctionsAir + +fn get_multiplicity_flags(s0: felt, s1: felt) -> felt[4]: + return [!s0 & !s1, s0 & !s1, !s0 & s1, s0 & s1] + +trace_columns: + main: [t, s0, s1, v] + aux: [b_range] + +public_inputs: + stack_inputs: [16] + +random_values: + alpha: [16] + +boundary_constraints: + enf v.first = 0 + +integrity_constraints: + let val = $alpha[0] + v + let f = get_multiplicity_flags(s0, s1) + let z = val^4 * f[3] + val^2 * f[2] + val * f[1] + f[0] + enf b_range' = b_range * (z * t - t + 1) \ No newline at end of file diff --git a/air-script/tests/functions/functions.masm b/air-script/tests/functions/functions.masm new file mode 100644 index 00000000..e02e16a2 --- /dev/null +++ b/air-script/tests/functions/functions.masm @@ -0,0 +1,161 @@ +# Procedure to efficiently compute the required exponentiations of the out-of-domain point `z` and cache them for later use. +# +# This computes the power of `z` needed to evaluate the periodic polynomials and the constraint divisors +# +# Input: [...] +# Output: [...] +proc.cache_z_exp + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, ...] + # Exponentiate z trace_len times + mem_load.4294903307 neg + # => [count, z_1, z_0, ...] where count = -log2(trace_len) + dup.0 neq.0 + while.true + movdn.2 dup.1 dup.1 ext2mul + # => [(e_1, e_0)^n, i, ...] + movup.2 add.1 dup.0 neq.0 + # => [b, i+1, (e_1, e_0)^n, ...] + end # END while + push.0 mem_storew.500000100 # z^trace_len + # => [0, 0, (z_1, z_0)^trace_len, ...] + dropw # Clean stack +end # END PROC cache_z_exp + +# Procedure to compute the exemption points. +# +# Input: [...] +# Output: [g^{-2}, g^{-1}, ...] +proc.get_exemptions_points + mem_load.4294799999 + # => [g, ...] + push.1 swap div + # => [g^{-1}, ...] + dup.0 dup.0 mul + # => [g^{-2}, g^{-1}, ...] +end # END PROC get_exemptions_points + +# Procedure to compute the integrity constraint divisor. +# +# The divisor is defined as `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))` +# Procedure `cache_z_exp` must have been called prior to this. +# +# Input: [...] +# Output: [divisor_1, divisor_0, ...] +proc.compute_integrity_constraint_divisor + padw mem_loadw.500000100 drop drop # load z^trace_len + # Comments below use zt = `z^trace_len` + # => [zt_1, zt_0, ...] + push.1 push.0 ext2sub + # => [zt_1-1, zt_0-1, ...] + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, zt_1-1, zt_0-1, ...] + exec.get_exemptions_points + # => [g^{trace_len-2}, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + dup.0 mem_store.500000101 # Save a copy of `g^{trace_len-2} to be used by the boundary divisor + dup.3 dup.3 movup.3 push.0 ext2sub + # => [e_1, e_0, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + movup.4 movup.4 movup.4 push.0 ext2sub + # => [e_3, e_2, e_1, e_0, zt_1-1, zt_0-1, ...] + ext2mul + # => [denominator_1, denominator_0, zt_1-1, zt_0-1, ...] + ext2div + # => [divisor_1, divisor_0, ...] +end # END PROC compute_integrity_constraint_divisor + +# Procedure to evaluate numerators of all integrity constraints. +# +# All the 0 main and 1 auxiliary constraints are evaluated. +# The result of each evaluation is kept on the stack, with the top of the stack +# containing the evaluations for the auxiliary trace (if any) followed by the main trace. +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# where: (r_1, r_0) is the quadratic extension element resulting from the integrity constraint evaluation. +# This procedure pushes 1 quadratic extension field elements to the stack +proc.compute_integrity_constraints + # integrity constraint 0 for aux + padw mem_loadw.4294900072 drop drop padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add + # push the accumulator to the stack + push.1 movdn.2 push.0 movdn.2 + # => [b1, b0, r1, r0, ...] + # square 2 times + dup.1 dup.1 ext2mul dup.1 dup.1 ext2mul + # multiply + dup.1 dup.1 movdn.5 movdn.5 + # => [b1, b0, r1, r0, b1, b0, ...] (4 cycles) + ext2mul movdn.3 movdn.3 + # => [b1, b0, r1', r0', ...] (5 cycles) + # clean stack + drop drop + # => [r1, r0, ...] (2 cycles) + padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add + # push the accumulator to the stack + push.1 movdn.2 push.0 movdn.2 + # => [b1, b0, r1, r0, ...] + # square 1 times + dup.1 dup.1 ext2mul + # multiply + dup.1 dup.1 movdn.5 movdn.5 + # => [b1, b0, r1, r0, b1, b0, ...] (4 cycles) + ext2mul movdn.3 movdn.3 + # => [b1, b0, r1', r0', ...] (5 cycles) + # clean stack + drop drop + # => [r1, r0, ...] (2 cycles) + push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2mul ext2add padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2add padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2mul ext2add push.1 push.0 padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop ext2sub push.1 push.0 ext2add ext2mul ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 movdn.3 movdn.3 drop drop ext2mul +end # END PROC compute_integrity_constraints + +# Procedure to evaluate the boundary constraint numerator for the first row of the main trace +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# Where: (r_1, r_0) is one quadratic extension field element for each constraint +proc.compute_boundary_constraints_main_first + # boundary constraint 0 for main + padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 drop drop ext2mul +end # END PROC compute_boundary_constraints_main_first + +# Procedure to evaluate all integrity constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_integrity_constraints + exec.compute_integrity_constraints + # Numerator of the transition constraint polynomial + ext2add + # Divisor of the transition constraint polynomial + exec.compute_integrity_constraint_divisor + ext2div # divide the numerator by the divisor +end # END PROC evaluate_integrity_constraints + +# Procedure to evaluate all boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_boundary_constraints + exec.compute_boundary_constraints_main_first + # => [(first1, first0), ...] + # Compute the denominator for domain FirstRow + padw mem_loadw.4294903304 drop drop # load z + push.1 push.0 ext2sub + # Compute numerator/denominator for first row + ext2div +end # END PROC evaluate_boundary_constraints + +# Procedure to evaluate the integrity and boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +export.evaluate_constraints + exec.cache_z_exp + exec.evaluate_integrity_constraints + exec.evaluate_boundary_constraints + ext2add +end # END PROC evaluate_constraints \ No newline at end of file diff --git a/air-script/tests/functions/functions.rs b/air-script/tests/functions/functions.rs new file mode 100644 index 00000000..6c92d02c --- /dev/null +++ b/air-script/tests/functions/functions.rs @@ -0,0 +1,90 @@ +use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement}; +use winter_utils::collections::Vec; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + stack_inputs: [Felt; 16], +} + +impl PublicInputs { + pub fn new(stack_inputs: [Felt; 16]) -> Self { + Self { stack_inputs } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + target.write(self.stack_inputs.as_slice()); + } +} + +pub struct FunctionsAir { + context: AirContext, + stack_inputs: [Felt; 16], +} + +impl FunctionsAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for FunctionsAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![]; + let aux_degrees = vec![TransitionConstraintDegree::new(8)]; + let num_main_assertions = 1; + let num_aux_assertions = 0; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, stack_inputs: public_inputs.stack_inputs } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(3, 0, Felt::ZERO)); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxTraceRandElements) -> Vec> { + let mut result = Vec::new(); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + result[0] = aux_next[0] - aux_current[0] * (((aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(4_u64)) * E::from(main_current[1]) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])).exp(E::PositiveInteger::from(2_u64)) * (E::ONE - E::from(main_current[1])) * E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(main_current[3])) * E::from(main_current[1]) * (E::ONE - E::from(main_current[2])) + (E::ONE - E::from(main_current[1])) * (E::ONE - E::from(main_current[2]))) * E::from(main_current[0]) - E::from(main_current[0]) + E::ONE); + } +} \ No newline at end of file diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index c39c7134..c2b50db3 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -41,6 +41,10 @@ pub enum Declaration { /// /// Evaluator functions can be defined in any module of the program EvaluatorFunction(EvaluatorFunction), + /// A pure function definition + /// + /// Pure functions can be defined in any module of the program + Function(Function), /// A `periodic_columns` section declaration /// /// This may appear any number of times in the program, and may be declared in any module. @@ -523,3 +527,54 @@ impl PartialEq for EvaluatorFunction { self.name == other.name && self.params == other.params && self.body == other.body } } + +/// Functions take a group of expressions as parameters and returns a group of expressions. These +/// values can be a Felt, a Vector or a Matrix. Functions do not take trace bindings as parameters. +#[derive(Debug, Clone, Spanned)] +pub struct Function { + #[span] + pub span: SourceSpan, + pub name: Identifier, + pub params: Vec<(Identifier, Type)>, + pub return_type: Type, + pub body: Vec, +} +impl Function { + /// Creates a new function. + pub const fn new( + span: SourceSpan, + name: Identifier, + params: Vec<(Identifier, Type)>, + return_type: Type, + body: Vec, + ) -> Self { + Self { + span, + name, + params, + return_type, + body, + } + } + + pub fn param_bindings(&self) -> Vec { + self.params + .iter() + .map(|(name, _)| *name) + .collect::>() + } + + pub fn param_types(&self) -> Vec { + self.params.iter().map(|(_, ty)| *ty).collect::>() + } +} + +impl Eq for Function {} +impl PartialEq for Function { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.params == other.params + && self.return_type == other.return_type + && self.body == other.body + } +} diff --git a/parser/src/ast/display.rs b/parser/src/ast/display.rs index e73a3d36..64211c32 100644 --- a/parser/src/ast/display.rs +++ b/parser/src/ast/display.rs @@ -10,7 +10,7 @@ impl fmt::Display for DisplayBracketed { } } -/// Displays a slice of items surrounded by brackets, e.g. `[foo]` +/// Displays a slice of items surrounded by brackets, e.g. `[foo, bar]` pub struct DisplayList<'a, T>(pub &'a [T]); impl<'a, T: fmt::Display> fmt::Display for DisplayList<'a, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -26,7 +26,7 @@ impl fmt::Display for DisplayParenthesized { } } -/// Displays a slice of items surrounded by parentheses, e.g. `(foo)` +/// Displays a slice of items surrounded by parentheses, e.g. `(foo, bar)` pub struct DisplayTuple<'a, T>(pub &'a [T]); impl<'a, T: fmt::Display> fmt::Display for DisplayTuple<'a, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -34,6 +34,19 @@ impl<'a, T: fmt::Display> fmt::Display for DisplayTuple<'a, T> { } } +/// Displays a slice of items with their types surrounded by parentheses, +/// e.g. `(foo: felt, bar: felt[12])` +pub struct DisplayTypedTuple<'a, V, T>(pub &'a [(V, T)]); +impl<'a, V: fmt::Display, T: fmt::Display> fmt::Display for DisplayTypedTuple<'a, V, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "({})", + DisplayCsv::new(self.0.iter().map(|(v, t)| format!("{}: {}", v, t))) + ) + } +} + /// Displays one or more items separated by commas, e.g. `foo, bar` pub struct DisplayCsv(Cell>); impl DisplayCsv @@ -97,6 +110,7 @@ impl<'a> fmt::Display for DisplayStatement<'a> { write!(f, "enf {}", expr) } Statement::Expr(ref expr) => write!(f, "{}", expr), + Statement::Return(ref expr) => writeln!(f, "return {}", expr), } } } diff --git a/parser/src/ast/mod.rs b/parser/src/ast/mod.rs index 5a9322eb..c4c6f200 100644 --- a/parser/src/ast/mod.rs +++ b/parser/src/ast/mod.rs @@ -71,6 +71,8 @@ pub struct Program { pub constants: BTreeMap, /// The set of used evaluator functions referenced in this program. pub evaluators: BTreeMap, + /// The set of used pure functions referenced in this program. + pub functions: BTreeMap, /// The set of used periodic columns referenced in this program. pub periodic_columns: BTreeMap, /// The set of public inputs defined in the root module @@ -115,6 +117,7 @@ impl Program { name, constants: Default::default(), evaluators: Default::default(), + functions: Default::default(), periodic_columns: Default::default(), public_inputs: Default::default(), random_values: None, @@ -265,7 +268,12 @@ impl Program { .entry(referenced) .or_insert_with(|| referenced_module.evaluators[&id].clone()); } - DependencyType::Function => unimplemented!(), + DependencyType::Function => { + program + .functions + .entry(referenced) + .or_insert_with(|| referenced_module.functions[&id].clone()); + } DependencyType::PeriodicColumn => { program .periodic_columns @@ -288,6 +296,7 @@ impl PartialEq for Program { self.name == other.name && self.constants == other.constants && self.evaluators == other.evaluators + && self.functions == other.functions && self.periodic_columns == other.periodic_columns && self.public_inputs == other.public_inputs && self.random_values == other.random_values @@ -377,6 +386,29 @@ impl fmt::Display for Program { f.write_str("\n")?; } + for (qid, function) in self.functions.iter() { + f.write_str("fn ")?; + if qid.module == self.name { + writeln!( + f, + "{}{}", + &qid.item, + DisplayTypedTuple(function.params.as_slice()) + )?; + } else { + writeln!( + f, + "{}{}", + qid, + DisplayTypedTuple(function.params.as_slice()) + )?; + } + + for statement in function.body.iter() { + writeln!(f, "{}", statement.display(1))?; + } + } + Ok(()) } } diff --git a/parser/src/ast/module.rs b/parser/src/ast/module.rs index 393f93db..c9363ca9 100644 --- a/parser/src/ast/module.rs +++ b/parser/src/ast/module.rs @@ -54,6 +54,7 @@ pub struct Module { pub imports: BTreeMap, pub constants: BTreeMap, pub evaluators: BTreeMap, + pub functions: BTreeMap, pub periodic_columns: BTreeMap, pub public_inputs: BTreeMap, pub random_values: Option, @@ -79,6 +80,7 @@ impl Module { imports: Default::default(), constants: Default::default(), evaluators: Default::default(), + functions: Default::default(), periodic_columns: Default::default(), public_inputs: Default::default(), random_values: None, @@ -121,6 +123,9 @@ impl Module { Declaration::EvaluatorFunction(evaluator) => { module.declare_evaluator(diagnostics, &mut names, evaluator)?; } + Declaration::Function(function) => { + module.declare_function(diagnostics, &mut names, function)?; + } Declaration::PeriodicColumns(mut columns) => { for column in columns.drain(..) { module.declare_periodic_column(diagnostics, &mut names, column)?; @@ -395,6 +400,22 @@ impl Module { Ok(()) } + fn declare_function( + &mut self, + diagnostics: &DiagnosticsHandler, + names: &mut HashSet, + function: Function, + ) -> Result<(), SemanticAnalysisError> { + if let Some(prev) = names.replace(NamespacedIdentifier::Function(function.name)) { + conflicting_declaration(diagnostics, "function", prev.span(), function.name.span()); + return Err(SemanticAnalysisError::NameConflict(function.name.span())); + } + + self.functions.insert(function.name, function); + + Ok(()) + } + fn declare_periodic_column( &mut self, diagnostics: &DiagnosticsHandler, @@ -621,6 +642,7 @@ impl PartialEq for Module { && self.imports == other.imports && self.constants == other.constants && self.evaluators == other.evaluators + && self.functions == other.functions && self.periodic_columns == other.periodic_columns && self.public_inputs == other.public_inputs && self.random_values == other.random_values diff --git a/parser/src/ast/statement.rs b/parser/src/ast/statement.rs index 78b01cd6..1e62cddd 100644 --- a/parser/src/ast/statement.rs +++ b/parser/src/ast/statement.rs @@ -60,6 +60,10 @@ pub enum Statement { /// Just like `Enforce`, except the constraint is contained in the body of a list comprehension, /// and must be enforced on every value produced by that comprehension. EnforceAll(ListComprehension), + /// Declares return values from a pure function. + /// + /// This is only valid in the body of a pure function, and must be the last statement in the body. + Return(Expr), } impl Statement { /// Checks this statement to see if it contains any constraints @@ -71,7 +75,7 @@ impl Statement { match self { Self::Enforce(_) | Self::EnforceIf(_, _) | Self::EnforceAll(_) => true, Self::Let(Let { body, .. }) => body.iter().any(|s| s.has_constraints()), - Self::Expr(_) => false, + Self::Expr(_) | Self::Return(_) => false, } } diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index a5cdfd4f..1833e564 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs @@ -72,9 +72,9 @@ impl Type { impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Self::Felt => f.write_str("field element"), - Self::Vector(n) => write!(f, "vector of length {}", n), - Self::Matrix(rows, cols) => write!(f, "matrix of {} rows and {} columns", rows, cols), + Self::Felt => f.write_str("felt"), + Self::Vector(n) => write!(f, "felt[{}]", n), + Self::Matrix(rows, cols) => write!(f, "felt[{}, {}]", rows, cols), } } } diff --git a/parser/src/ast/visit.rs b/parser/src/ast/visit.rs index d879ecaf..ba1f282c 100644 --- a/parser/src/ast/visit.rs +++ b/parser/src/ast/visit.rs @@ -122,6 +122,9 @@ pub trait VisitMut { ) -> ControlFlow { visit_mut_evaluator_function(self, expr) } + fn visit_mut_function(&mut self, expr: &mut ast::Function) -> ControlFlow { + visit_mut_function(self, expr) + } fn visit_mut_periodic_column(&mut self, expr: &mut ast::PeriodicColumn) -> ControlFlow { visit_mut_periodic_column(self, expr) } @@ -223,6 +226,12 @@ pub trait VisitMut { fn visit_mut_identifier(&mut self, expr: &mut ast::Identifier) -> ControlFlow { visit_mut_identifier(self, expr) } + fn visit_mut_typed_identifier( + &mut self, + expr: &mut (ast::Identifier, ast::Type), + ) -> ControlFlow { + visit_mut_typed_identifier(self, expr) + } } impl<'a, V, T> VisitMut for &'a mut V @@ -244,6 +253,9 @@ where ) -> ControlFlow { (**self).visit_mut_evaluator_function(expr) } + fn visit_mut_function(&mut self, expr: &mut ast::Function) -> ControlFlow { + (**self).visit_mut_function(expr) + } fn visit_mut_periodic_column(&mut self, expr: &mut ast::PeriodicColumn) -> ControlFlow { (**self).visit_mut_periodic_column(expr) } @@ -344,6 +356,12 @@ where fn visit_mut_identifier(&mut self, expr: &mut ast::Identifier) -> ControlFlow { (**self).visit_mut_identifier(expr) } + fn visit_mut_typed_identifier( + &mut self, + expr: &mut (ast::Identifier, ast::Type), + ) -> ControlFlow { + (**self).visit_mut_typed_identifier(expr) + } } pub fn visit_mut_module(visitor: &mut V, module: &mut ast::Module) -> ControlFlow @@ -359,6 +377,9 @@ where for evaluator in module.evaluators.values_mut() { visitor.visit_mut_evaluator_function(evaluator)?; } + for function in module.functions.values_mut() { + visitor.visit_mut_function(function)?; + } for column in module.periodic_columns.values_mut() { visitor.visit_mut_periodic_column(column)?; } @@ -439,6 +460,17 @@ where visitor.visit_mut_statement_block(&mut expr.body) } +pub fn visit_mut_function(visitor: &mut V, expr: &mut ast::Function) -> ControlFlow +where + V: ?Sized + VisitMut, +{ + visitor.visit_mut_identifier(&mut expr.name)?; + for param in expr.params.iter_mut() { + visitor.visit_mut_typed_identifier(param)?; + } + visitor.visit_mut_statement_block(&mut expr.body) +} + pub fn visit_mut_evaluator_trace_segment( visitor: &mut V, expr: &mut ast::TraceSegment, @@ -530,6 +562,7 @@ where } ast::Statement::EnforceAll(ref mut expr) => visitor.visit_mut_enforce_all(expr), ast::Statement::Expr(ref mut expr) => visitor.visit_mut_expr(expr), + ast::Statement::Return(ref mut expr) => visitor.visit_mut_expr(expr), } } @@ -661,3 +694,13 @@ where { ControlFlow::Continue(()) } + +pub fn visit_mut_typed_identifier( + _visitor: &mut V, + _expr: &mut (ast::Identifier, ast::Type), +) -> ControlFlow +where + V: ?Sized + VisitMut, +{ + ControlFlow::Continue(()) +} diff --git a/parser/src/lexer/mod.rs b/parser/src/lexer/mod.rs index cc9af75d..da19a9f4 100644 --- a/parser/src/lexer/mod.rs +++ b/parser/src/lexer/mod.rs @@ -113,6 +113,8 @@ pub enum Token { RandomValues, /// Keyword to declare the evaluator function section in the AIR constraints module. Ev, + /// Keyword to declare the function section in the AIR constraints module. + Fn, // BOUNDARY CONSTRAINT KEYWORDS // -------------------------------------------------------------------------------------------- @@ -137,9 +139,11 @@ pub enum Token { // -------------------------------------------------------------------------------------------- /// Keyword to signify that a constraint needs to be enforced Enf, + Return, Match, Case, When, + Felt, // PUNCTUATION // -------------------------------------------------------------------------------------------- @@ -161,6 +165,7 @@ pub enum Token { Ampersand, Bar, Bang, + Arrow, } impl Token { pub fn from_keyword_or_ident(s: &str) -> Self { @@ -177,6 +182,8 @@ impl Token { "periodic_columns" => Self::PeriodicColumns, "random_values" => Self::RandomValues, "ev" => Self::Ev, + "fn" => Self::Fn, + "felt" => Self::Felt, "boundary_constraints" => Self::BoundaryConstraints, "integrity_constraints" => Self::IntegrityConstraints, "first" => Self::First, @@ -184,6 +191,7 @@ impl Token { "for" => Self::For, "in" => Self::In, "enf" => Self::Enf, + "return" => Self::Return, "match" => Self::Match, "case" => Self::Case, "when" => Self::When, @@ -247,6 +255,8 @@ impl fmt::Display for Token { Self::PeriodicColumns => write!(f, "periodic_columns"), Self::RandomValues => write!(f, "random_values"), Self::Ev => write!(f, "ev"), + Self::Fn => write!(f, "fn"), + Self::Felt => write!(f, "felt"), Self::BoundaryConstraints => write!(f, "boundary_constraints"), Self::First => write!(f, "first"), Self::Last => write!(f, "last"), @@ -254,6 +264,7 @@ impl fmt::Display for Token { Self::For => write!(f, "for"), Self::In => write!(f, "in"), Self::Enf => write!(f, "enf"), + Self::Return => write!(f, "return"), Self::Match => write!(f, "match"), Self::Case => write!(f, "case"), Self::When => write!(f, "when"), @@ -275,6 +286,7 @@ impl fmt::Display for Token { Self::Ampersand => write!(f, "&"), Self::Bar => write!(f, "|"), Self::Bang => write!(f, "!"), + Self::Arrow => write!(f, "->"), } } } @@ -486,7 +498,10 @@ where ']' => pop!(self, Token::RBracket), '=' => pop!(self, Token::Equal), '+' => pop!(self, Token::Plus), - '-' => pop!(self, Token::Minus), + '-' => match self.peek() { + '>' => pop2!(self, Token::Arrow), + _ => pop!(self, Token::Minus), + }, '*' => pop!(self, Token::Star), '^' => pop!(self, Token::Caret), '&' => pop!(self, Token::Ampersand), diff --git a/parser/src/lexer/tests/functions.rs b/parser/src/lexer/tests/functions.rs new file mode 100644 index 00000000..4e5e49a3 --- /dev/null +++ b/parser/src/lexer/tests/functions.rs @@ -0,0 +1,83 @@ +use super::{expect_valid_tokenization, Symbol, Token}; + +// FUNCTION VALID TOKENIZATION +// ================================================================================================ + +#[test] +fn fn_with_scalars() { + let source = "fn fn_name(a: felt, b: felt) -> felt: + return a + b"; + + let tokens = [ + Token::Fn, + Token::FunctionIdent(Symbol::intern("fn_name")), + Token::LParen, + Token::Ident(Symbol::intern("a")), + Token::Colon, + Token::Felt, + Token::Comma, + Token::Ident(Symbol::intern("b")), + Token::Colon, + Token::Felt, + Token::RParen, + Token::Arrow, + Token::Felt, + Token::Colon, + Token::Return, + Token::Ident(Symbol::intern("a")), + Token::Plus, + Token::Ident(Symbol::intern("b")), + ]; + + expect_valid_tokenization(source, tokens.to_vec()); +} + +#[test] +fn fn_with_vectors() { + let source = "fn fn_name(a: felt[12], b: felt[12]) -> felt[12]: + return [x + y for x, y in (a, b)]"; + + let tokens = [ + Token::Fn, + Token::FunctionIdent(Symbol::intern("fn_name")), + Token::LParen, + Token::Ident(Symbol::intern("a")), + Token::Colon, + Token::Felt, + Token::LBracket, + Token::Num(12), + Token::RBracket, + Token::Comma, + Token::Ident(Symbol::intern("b")), + Token::Colon, + Token::Felt, + Token::LBracket, + Token::Num(12), + Token::RBracket, + Token::RParen, + Token::Arrow, + Token::Felt, + Token::LBracket, + Token::Num(12), + Token::RBracket, + Token::Colon, + Token::Return, + Token::LBracket, + Token::Ident(Symbol::intern("x")), + Token::Plus, + Token::Ident(Symbol::intern("y")), + Token::For, + Token::Ident(Symbol::intern("x")), + Token::Comma, + Token::Ident(Symbol::intern("y")), + Token::In, + Token::LParen, + Token::Ident(Symbol::intern("a")), + Token::Comma, + Token::Ident(Symbol::intern("b")), + Token::RParen, + Token::RBracket, + ]; + + expect_valid_tokenization(source, tokens.to_vec()); +} diff --git a/parser/src/lexer/tests/mod.rs b/parser/src/lexer/tests/mod.rs index 0426e68a..a312c007 100644 --- a/parser/src/lexer/tests/mod.rs +++ b/parser/src/lexer/tests/mod.rs @@ -6,6 +6,7 @@ mod arithmetic_ops; mod boundary_constraints; mod constants; mod evaluator_functions; +mod functions; mod identifiers; mod list_comprehension; mod modules; diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index ac0736a8..26a63e68 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -75,6 +75,7 @@ Declaration: Declaration = { PeriodicColumns => Declaration::PeriodicColumns(<>), RandomValues => Declaration::RandomValues(<>), EvaluatorFunction => Declaration::EvaluatorFunction(<>), + Function => Declaration::Function(<>), => Declaration::Trace(Span::new(span!(l, r), trace)), => Declaration::PublicInputs(<>), => Declaration::BoundaryConstraints(<>), @@ -256,6 +257,51 @@ EvaluatorSegmentBindings: (SourceSpan, Vec>) = { "[" "]" => (span!(l, r), vec![]), } +// FUNCTIONS +// ================================================================================================ + +Function: Function = { + "fn" "(" ")" "->" + ":" + => Function::new(span!(l, r), name, params, return_type, body) +} + +FunctionBindings: Vec<(Identifier, Type)> = { + > => params, +} + +FunctionBinding: (Identifier, Type) = { + ":" => (name, ty), +} + +FunctionBindingType: Type = { + "felt" => Type::Felt, + "felt" => Type::Vector(size as usize), + "felt" "[" "," "]" => Type::Matrix(row_size as usize, col_size as usize), +} + +FunctionBody: Vec = { + =>? { + if let Some(mut stmts) = let_stmts { + if stmts.iter().any(|stmt| !matches!(stmt, Statement::Let(_))) { + diagnostics.diagnostic(Severity::Error) + .with_message("invalid function definition") + .with_primary_label(span!(l, r), "only let statements are allowed in function definitions") + .emit(); + return Err(ParseError::Failed.into()); + } + stmts.push(return_statement); + Ok(stmts) + } else { + Ok(vec![return_statement]) + } + }, +} + +ReturnStatement: Statement = { + "return" => Statement::Return(expr), +} + // BOUNDARY CONSTRAINTS // ================================================================================================ @@ -533,6 +579,15 @@ Matrix: Vec> = { Vector>, } +Tuple: Vec = { + "(" "," )*> ")" => { + let mut v = v; + v.insert(0, v2); + v.insert(0, v1); + v + } +}; + Size: u64 = { "[" "]" => <> } @@ -591,10 +646,13 @@ extern { "last" => Token::Last, "integrity_constraints" => Token::IntegrityConstraints, "ev" => Token::Ev, + "fn" => Token::Fn, "enf" => Token::Enf, + "return" => Token::Return, "match" => Token::Match, "case" => Token::Case, "when" => Token::When, + "felt" => Token::Felt, "'" => Token::Quote, "=" => Token::Equal, "+" => Token::Plus, @@ -613,5 +671,6 @@ extern { ")" => Token::RParen, "." => Token::Dot, ".." => Token::DotDot, + "->" => Token::Arrow, } } diff --git a/parser/src/parser/tests/functions.rs b/parser/src/parser/tests/functions.rs new file mode 100644 index 00000000..803050bb --- /dev/null +++ b/parser/src/parser/tests/functions.rs @@ -0,0 +1,55 @@ +use miden_diagnostics::SourceSpan; + +use crate::ast::*; + +use super::ParseTest; + +// FUNCTIONS +// ================================================================================================ + +#[test] +fn fn_with_scalars() { + let source = " + mod test + + fn fn_with_scalars(a: felt, b: felt) -> felt: + return a + b"; + + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + expected.functions.insert( + ident!(fn_with_scalars), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fn_with_scalars), + vec![(ident!(a), Type::Felt), (ident!(b), Type::Felt)], + Type::Felt, + vec![return_!(expr!(add!(access!(a), access!(b))))], + ), + ); + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn fn_with_vectors() { + let source = " + mod test + + fn fn_with_vectors(a: felt[12], b: felt[12]) -> felt[12]: + return [x + y for (x, y) in (a, b)]"; + + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + expected.functions.insert( + ident!(fn_with_vectors), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fn_with_vectors), + vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], + Type::Vector(12), + vec![return_!(expr!( + lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => + add!(access!(x), access!(y))) + ))], + ), + ); + ParseTest::new().expect_module_ast(source, expected); +} diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index cd614fb9..444f0582 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -441,6 +441,12 @@ macro_rules! let_ { }; } +macro_rules! return_ { + ($value:expr) => { + Statement::Return($value) + }; +} + macro_rules! enforce { ($expr:expr) => { Statement::Enforce($expr) @@ -599,6 +605,7 @@ mod calls; mod constant_propagation; mod constants; mod evaluators; +mod functions; mod identifiers; mod inlining; mod integrity_constraints; diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 3bae4c32..40de33cb 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -253,6 +253,23 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { ); } + for (function_name, function) in module.functions.iter() { + let namespaced_name = NamespacedIdentifier::Function(*function_name); + if let Some((prev, _)) = self.imported.get_key_value(&namespaced_name) { + self.declaration_import_conflict(namespaced_name.span(), prev.span())?; + } + assert_eq!( + self.locals.insert( + namespaced_name, + BindingType::Function(FunctionType::Function( + function.param_types(), + function.return_type + )) + ), + None + ); + } + // Next, we add any periodic columns to the set of local bindings. // // These _can_ conflict with globally defined names, but are guaranteed not to conflict @@ -287,6 +304,10 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { self.visit_mut_evaluator_function(evaluator)?; } + for function in module.functions.values_mut() { + self.visit_mut_function(function)?; + } + if let Some(boundary_constraints) = module.boundary_constraints.as_mut() { if !boundary_constraints.is_empty() { self.visit_mut_boundary_constraints(boundary_constraints)?; @@ -363,6 +384,48 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { ControlFlow::Continue(()) } + fn visit_mut_function( + &mut self, + function: &mut Function, + ) -> ControlFlow { + // constraints are not allowed in pure functions + self.constraint_mode = ConstraintMode::None; + + // Start a new lexical scope + self.locals.enter(); + + // Track referenced imports in a new context, as we want to update the dependency graph + // for this function using only those imports referenced from this function body + let referenced = mem::take(&mut self.referenced); + + // Add the set of parameters to the current scope, check for conflicts + for (param, param_type) in function.params.iter_mut() { + let namespaced_name = NamespacedIdentifier::Binding(*param); + self.locals + .insert(namespaced_name, BindingType::Local(*param_type)); + } + + // Visit all of the statements in the body + self.visit_mut_statement_block(&mut function.body)?; + + // Update the dependency graph for this function + let current_item = QualifiedIdentifier::new( + self.current_module.unwrap(), + NamespacedIdentifier::Function(function.name), + ); + for (referenced_item, ref_type) in self.referenced.iter() { + let referenced_item = self.deps.add_node(*referenced_item); + self.deps.add_edge(current_item, referenced_item, *ref_type); + } + + // Restore the original references metadata + self.referenced = referenced; + // Restore the original lexical scope + self.locals.exit(); + + ControlFlow::Continue(()) + } + fn visit_mut_boundary_constraints( &mut self, body: &mut Vec, @@ -598,6 +661,7 @@ impl<'a> VisitMut for SemanticAnalysis<'a> { } // TODO: When we have non-evaluator functions, we must fetch the type in its signature here, // and store it as the type of the Call expression + expr.ty = fty.result(); } } else { self.has_type_errors = true; diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index ffb27629..1b119a8c 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -69,6 +69,11 @@ impl<'a> ConstantPropagation<'a> { self.visit_mut_evaluator_function(evaluator)?; } + // Visit all of the functions + for function in program.functions.values_mut() { + self.visit_mut_function(function)?; + } + // Visit all of the constraints self.visit_mut_boundary_constraints(&mut program.boundary_constraints)?; self.visit_mut_integrity_constraints(&mut program.integrity_constraints) @@ -529,6 +534,7 @@ impl<'a> VisitMut for ConstantPropagation<'a> { } // This statement type is only present in the AST after inlining Statement::EnforceIf(_, _) => unreachable!(), + Statement::Return(ref mut expr) => self.visit_mut_expr(expr)?, } // If we have a non-empty buffer, then we are collapsing a let into the current block, diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs index 25e0f4ae..90162bfe 100644 --- a/parser/src/transforms/inlining.rs +++ b/parser/src/transforms/inlining.rs @@ -19,12 +19,12 @@ use super::constant_propagation; /// * Monomorphizing and inlining evaluators/functions at their call sites /// * Unrolling constraint comprehensions into a sequence of scalar constraints /// * Unrolling list comprehensions into a tree of `let` statements which end in -/// a vector expression (the implicit result of the tree). Each iteration of the -/// unrolled comprehension is reified as a value and bound to a variable so that -/// other transformations may refer to it directly. +/// a vector expression (the implicit result of the tree). Each iteration of the +/// unrolled comprehension is reified as a value and bound to a variable so that +/// other transformations may refer to it directly. /// * Rewriting aliases of top-level declarations to refer to those declarations directly /// * Removing let-bound variables which are unused, which is also used to clean up -/// after the aliasing rewrite mentioned above. +/// after the aliasing rewrite mentioned above. /// /// The trickiest transformation comes with inlining the body of evaluators at their /// call sites, as evaluator parameter lists can arbitrarily destructure/regroup columns @@ -75,6 +75,8 @@ pub struct Inlining<'a> { imported: HashMap, /// All evaluator functions in the program evaluators: HashMap, + /// All pure functions in the program + functions: HashMap, /// A set of identifiers for which accesses should be rewritten. /// /// When an identifier is in this set, it means it is a local alias for a trace column, @@ -97,6 +99,12 @@ impl<'p> Pass for Inlining<'p> { .map(|(k, v)| (*k, v.clone())) .collect(); + self.functions = program + .functions + .iter() + .map(|(k, v)| (*k, v.clone())) + .collect(); + // We'll be referencing the trace configuration during inlining, so keep a copy of it self.trace = program.trace_columns.clone(); // Same with the random values @@ -187,6 +195,7 @@ impl<'a> Inlining<'a> { let_bound: Default::default(), imported: Default::default(), evaluators: Default::default(), + functions: Default::default(), rewrites: Default::default(), in_comprehension_constraint: false, next_ident: 0, @@ -297,6 +306,11 @@ impl<'a> Inlining<'a> { self.rewrite_expr(&mut expr)?; Ok(vec![Statement::Expr(expr)]) } + + Statement::Return(mut expr) => { + self.rewrite_expr(&mut expr)?; + Ok(vec![Statement::Return(expr)]) + } } } @@ -431,7 +445,7 @@ impl<'a> Inlining<'a> { other => unimplemented!("unhandled builtin: {}", other), } } else { - todo!("pure functions are not implemented yet") + self.expand_function_callsite(call) } } @@ -954,7 +968,7 @@ impl<'a> Inlining<'a> { // NOTE: We create a new nested scope for the parameters in order to avoid conflicting // with the root declarations eval_bindings.enter(); - self.populate_rewrites( + self.populate_rewrites_evaluator( &mut eval_bindings, call.args.as_slice(), evaluator.params.as_slice(), @@ -972,10 +986,121 @@ impl<'a> Inlining<'a> { Ok(evaluator.body) } + /// This function handles inlining pure function calls. + fn expand_function_callsite( + &mut self, + call: Call, + ) -> Result, SemanticAnalysisError> { + // The callee is guaranteed to be resolved and exist at this point + let callee = call + .callee + .resolved() + .expect("callee should have been resolved by now"); + + // We clone the function here as we will be modifying the body during the + // inlining process, and we must not modify the original + let mut function = self.functions.get(&callee).unwrap().clone(); + + // This will be the initial set of bindings visible within the function body + // + // This is distinct from `self.bindings` at this point, because the function doesn't + // inherit the caller's scope, it has an entirely new one. + let mut function_bindings = LexicalScope::default(); + + // Add all referenced (and thus imported) items from the function module + // + // NOTE: This will include constants, periodic columns, and other functions + for (qid, binding_ty) in self.imported.iter() { + if qid.module == callee.module { + function_bindings.insert(*qid.as_ref(), binding_ty.clone()); + } + } + + // Add random values, trace columns, and other root declarations to the set of + // bindings visible in the function body, _if_ the function is defined in the + // root module. + let is_function_in_root = callee.module == self.root; + if is_function_in_root { + if let Some(rv) = self.random_values.as_ref() { + function_bindings.insert( + rv.name, + BindingType::RandomValue(RandBinding::new( + rv.name.span(), + rv.name, + rv.size, + 0, + Type::Vector(rv.size), + )), + ); + for binding in rv.bindings.iter().copied() { + function_bindings.insert(binding.name, BindingType::RandomValue(binding)); + } + } + + for segment in self.trace.iter() { + function_bindings.insert( + segment.name, + BindingType::TraceColumn(TraceBinding { + span: segment.name.span(), + segment: segment.id, + name: Some(segment.name), + offset: 0, + size: segment.size, + ty: Type::Vector(segment.size), + }), + ); + for binding in segment.bindings.iter().copied() { + function_bindings.insert( + binding.name.unwrap(), + BindingType::TraceColumn(TraceBinding { + span: segment.name.span(), + segment: segment.id, + name: binding.name, + offset: binding.offset, + size: binding.size, + ty: binding.ty, + }), + ); + } + } + + for input in self.public_inputs.values() { + function_bindings.insert( + input.name, + BindingType::PublicInput(Type::Vector(input.size)), + ); + } + } + // TODO: validate arguments + + // Match call arguments to function parameters, populating the set of rewrites + // which should be performed on the inlined function body. + // + // NOTE: We create a new nested scope for the parameters in order to avoid conflicting + // with the root declarations + function_bindings.enter(); + self.populate_rewrites_function( + &mut function_bindings, + call.args.as_slice(), + function.params.as_slice(), + ); + + // While we're inlining the body, use the set of function bindings we built above + let prev_bindings = core::mem::replace(&mut self.bindings, function_bindings); + + // Expand the evaluator body into a block of statements + self.expand_statement_block(&mut function.body)?; + + // Restore the caller's bindings before we leave + self.bindings = prev_bindings; + + Ok(function.body) + } + /// Populate the set of access rewrites, as well as the initial set of bindings to use when inlining an evaluator function. /// /// This is done by resolving the arguments provided by the call to the evaluator, with the parameter list of the evaluator itself. - fn populate_rewrites( + fn populate_rewrites_evaluator( &mut self, eval_bindings: &mut LexicalScope, args: &[Expr], @@ -1160,6 +1285,31 @@ impl<'a> Inlining<'a> { } } + fn populate_rewrites_function( + &mut self, + function_bindings: &mut LexicalScope, + args: &[Expr], + params: &[(Identifier, Type)], + ) { + // Reset the rewrites set + self.rewrites.clear(); + + for (arg, param) in args.iter().zip(params.iter()) { + match arg { + Expr::SymbolAccess(access) => { + let mut binding_ty = Some(self.access_binding_type(access).unwrap()); + let binding_name = param.0; + // We can safely assume that there is a binding type available here, + // otherwise the semantic analysis pass missed something + let bt = binding_ty.take().unwrap(); + self.rewrites.insert(binding_name); + function_bindings.insert(binding_name, bt); + } + _ => todo!() + } + } + } + /// Returns a new [SymbolAccess] which should be used in place of `access` in the current scope. /// /// This function should only be called on accesses which have a trace column/param [BindingType], @@ -1445,6 +1595,7 @@ impl<'a> VisitMut for ApplyConstraintSelector<'a> { } Statement::EnforceAll(_) => unreachable!(), Statement::Expr(_) => ControlFlow::Continue(()), + Statement::Return(_) => ControlFlow::Continue(()), } } } @@ -1500,7 +1651,7 @@ where // is the effective value of the `let` tree. We will replace this // node if the callback we were given returns a new `Statement`. In // either case, we're done once we've handled the callback result. - Statement::Expr(ref mut value) => match callback(inliner, value) { + Statement::Expr(ref mut value) | Statement::Return(ref mut value) => match callback(inliner, value) { Ok(Some(replacement)) => { parent_block.pop(); parent_block.push(replacement);