From 058d56f96fa7d63ce14fa7b9d9e01f0b49d86bda Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Thu, 27 Jul 2023 02:43:17 +0000 Subject: [PATCH] feat: add support for functions --- parser/src/ast/declarations.rs | 66 ++++++++++++++++++++++ parser/src/ast/mod.rs | 6 ++ parser/src/ast/module.rs | 22 ++++++++ parser/src/ast/types.rs | 24 ++++++++ parser/src/lexer/mod.rs | 17 +++++- parser/src/lexer/tests/functions.rs | 83 ++++++++++++++++++++++++++++ parser/src/lexer/tests/mod.rs | 1 + parser/src/parser/grammar.lalrpop | 74 +++++++++++++++++++++++++ parser/src/parser/tests/functions.rs | 59 ++++++++++++++++++++ parser/src/parser/tests/mod.rs | 1 + parser/src/transforms/inlining.rs | 19 +++++-- 11 files changed, 367 insertions(+), 5 deletions(-) create mode 100644 parser/src/lexer/tests/functions.rs create mode 100644 parser/src/parser/tests/functions.rs diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index c39c7134..89cc4e8a 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -41,6 +41,10 @@ pub enum Declaration { /// /// Evaluator functions can be defined in any module of the program EvaluatorFunction(EvaluatorFunction), + /// A pure function definition + /// + /// Pure functions can be defined in any module of the program + Function(Function), /// A `periodic_columns` section declaration /// /// This may appear any number of times in the program, and may be declared in any module. @@ -523,3 +527,65 @@ impl PartialEq for EvaluatorFunction { self.name == other.name && self.params == other.params && self.body == other.body } } + +/// Functions take a group of expressions as parameters and returns a group of expressions. These +/// values can be a Felt, a Vector or a Matrix. Functions do not take trace bindings as parameters. +#[derive(Debug, Clone, Spanned)] +pub struct Function { + #[span] + pub span: SourceSpan, + pub name: Identifier, + pub params: Vec<(Identifier, Type)>, + pub return_types: Vec, + pub body: FunctionBody, +} +impl Function { + /// Creates a new function. + pub const fn new( + span: SourceSpan, + name: Identifier, + params: Vec<(Identifier, Type)>, + return_types: Vec, + body: FunctionBody, + ) -> Self { + Self { + span, + name, + params, + return_types, + body, + } + } +} +impl Eq for Function {} +impl PartialEq for Function { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.params == other.params + && self.return_types == other.return_types + && self.body == other.body + } +} + +#[derive(Debug, Clone, Spanned)] +pub struct FunctionBody { + #[span] + pub span: SourceSpan, + pub body: Vec, + pub return_values: Vec, +} +impl FunctionBody { + pub const fn new(span: SourceSpan, body: Vec, return_values: Vec) -> Self { + Self { + span, + body, + return_values, + } + } +} +impl Eq for FunctionBody {} +impl PartialEq for FunctionBody { + fn eq(&self, other: &Self) -> bool { + self.body == other.body && self.return_values == other.return_values + } +} diff --git a/parser/src/ast/mod.rs b/parser/src/ast/mod.rs index 5a9322eb..0e2796d5 100644 --- a/parser/src/ast/mod.rs +++ b/parser/src/ast/mod.rs @@ -71,6 +71,8 @@ pub struct Program { pub constants: BTreeMap, /// The set of used evaluator functions referenced in this program. pub evaluators: BTreeMap, + /// The set of used pure functions referenced in this program. + pub functions: BTreeMap, /// The set of used periodic columns referenced in this program. pub periodic_columns: BTreeMap, /// The set of public inputs defined in the root module @@ -115,6 +117,7 @@ impl Program { name, constants: Default::default(), evaluators: Default::default(), + functions: Default::default(), periodic_columns: Default::default(), public_inputs: Default::default(), random_values: None, @@ -288,6 +291,7 @@ impl PartialEq for Program { self.name == other.name && self.constants == other.constants && self.evaluators == other.evaluators + && self.functions == other.functions && self.periodic_columns == other.periodic_columns && self.public_inputs == other.public_inputs && self.random_values == other.random_values @@ -377,6 +381,8 @@ impl fmt::Display for Program { f.write_str("\n")?; } + // TODO: functions + Ok(()) } } diff --git a/parser/src/ast/module.rs b/parser/src/ast/module.rs index 393f93db..c9363ca9 100644 --- a/parser/src/ast/module.rs +++ b/parser/src/ast/module.rs @@ -54,6 +54,7 @@ pub struct Module { pub imports: BTreeMap, pub constants: BTreeMap, pub evaluators: BTreeMap, + pub functions: BTreeMap, pub periodic_columns: BTreeMap, pub public_inputs: BTreeMap, pub random_values: Option, @@ -79,6 +80,7 @@ impl Module { imports: Default::default(), constants: Default::default(), evaluators: Default::default(), + functions: Default::default(), periodic_columns: Default::default(), public_inputs: Default::default(), random_values: None, @@ -121,6 +123,9 @@ impl Module { Declaration::EvaluatorFunction(evaluator) => { module.declare_evaluator(diagnostics, &mut names, evaluator)?; } + Declaration::Function(function) => { + module.declare_function(diagnostics, &mut names, function)?; + } Declaration::PeriodicColumns(mut columns) => { for column in columns.drain(..) { module.declare_periodic_column(diagnostics, &mut names, column)?; @@ -395,6 +400,22 @@ impl Module { Ok(()) } + fn declare_function( + &mut self, + diagnostics: &DiagnosticsHandler, + names: &mut HashSet, + function: Function, + ) -> Result<(), SemanticAnalysisError> { + if let Some(prev) = names.replace(NamespacedIdentifier::Function(function.name)) { + conflicting_declaration(diagnostics, "function", prev.span(), function.name.span()); + return Err(SemanticAnalysisError::NameConflict(function.name.span())); + } + + self.functions.insert(function.name, function); + + Ok(()) + } + fn declare_periodic_column( &mut self, diagnostics: &DiagnosticsHandler, @@ -621,6 +642,7 @@ impl PartialEq for Module { && self.imports == other.imports && self.constants == other.constants && self.evaluators == other.evaluators + && self.functions == other.functions && self.periodic_columns == other.periodic_columns && self.public_inputs == other.public_inputs && self.random_values == other.random_values diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index a5cdfd4f..8f28dc2d 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs @@ -79,6 +79,30 @@ impl fmt::Display for Type { } } +// impl std::str::FromStr for Type { +// type Err = (); +// fn from_str(s: &str) -> Result { +// let s = s.trim(); +// if s == "felt" { +// return Some(Type::Felt); +// } +// // Attempt to parse as Vector or Matrix +// if s.starts_with("felt[") && s.ends_with("]") { +// let contents = &s[5..s.len()-1]; // Extract inner content without brackets +// let numbers: Vec = contents.split(',') +// .map(str::trim) +// .filter_map(|num| num.parse().ok()) +// .collect(); +// match numbers.len() { +// 1 => return Some(Type::Vector(numbers[0])), +// 2 => return Some(Type::Matrix(numbers[0], numbers[1])), +// _ => return None, +// } +// } +// None +// } +// } + /// Represents the type signature of a function #[derive(Debug, Clone, PartialEq, Eq)] pub enum FunctionType { diff --git a/parser/src/lexer/mod.rs b/parser/src/lexer/mod.rs index cc9af75d..da19a9f4 100644 --- a/parser/src/lexer/mod.rs +++ b/parser/src/lexer/mod.rs @@ -113,6 +113,8 @@ pub enum Token { RandomValues, /// Keyword to declare the evaluator function section in the AIR constraints module. Ev, + /// Keyword to declare the function section in the AIR constraints module. + Fn, // BOUNDARY CONSTRAINT KEYWORDS // -------------------------------------------------------------------------------------------- @@ -137,9 +139,11 @@ pub enum Token { // -------------------------------------------------------------------------------------------- /// Keyword to signify that a constraint needs to be enforced Enf, + Return, Match, Case, When, + Felt, // PUNCTUATION // -------------------------------------------------------------------------------------------- @@ -161,6 +165,7 @@ pub enum Token { Ampersand, Bar, Bang, + Arrow, } impl Token { pub fn from_keyword_or_ident(s: &str) -> Self { @@ -177,6 +182,8 @@ impl Token { "periodic_columns" => Self::PeriodicColumns, "random_values" => Self::RandomValues, "ev" => Self::Ev, + "fn" => Self::Fn, + "felt" => Self::Felt, "boundary_constraints" => Self::BoundaryConstraints, "integrity_constraints" => Self::IntegrityConstraints, "first" => Self::First, @@ -184,6 +191,7 @@ impl Token { "for" => Self::For, "in" => Self::In, "enf" => Self::Enf, + "return" => Self::Return, "match" => Self::Match, "case" => Self::Case, "when" => Self::When, @@ -247,6 +255,8 @@ impl fmt::Display for Token { Self::PeriodicColumns => write!(f, "periodic_columns"), Self::RandomValues => write!(f, "random_values"), Self::Ev => write!(f, "ev"), + Self::Fn => write!(f, "fn"), + Self::Felt => write!(f, "felt"), Self::BoundaryConstraints => write!(f, "boundary_constraints"), Self::First => write!(f, "first"), Self::Last => write!(f, "last"), @@ -254,6 +264,7 @@ impl fmt::Display for Token { Self::For => write!(f, "for"), Self::In => write!(f, "in"), Self::Enf => write!(f, "enf"), + Self::Return => write!(f, "return"), Self::Match => write!(f, "match"), Self::Case => write!(f, "case"), Self::When => write!(f, "when"), @@ -275,6 +286,7 @@ impl fmt::Display for Token { Self::Ampersand => write!(f, "&"), Self::Bar => write!(f, "|"), Self::Bang => write!(f, "!"), + Self::Arrow => write!(f, "->"), } } } @@ -486,7 +498,10 @@ where ']' => pop!(self, Token::RBracket), '=' => pop!(self, Token::Equal), '+' => pop!(self, Token::Plus), - '-' => pop!(self, Token::Minus), + '-' => match self.peek() { + '>' => pop2!(self, Token::Arrow), + _ => pop!(self, Token::Minus), + }, '*' => pop!(self, Token::Star), '^' => pop!(self, Token::Caret), '&' => pop!(self, Token::Ampersand), diff --git a/parser/src/lexer/tests/functions.rs b/parser/src/lexer/tests/functions.rs new file mode 100644 index 00000000..4e5e49a3 --- /dev/null +++ b/parser/src/lexer/tests/functions.rs @@ -0,0 +1,83 @@ +use super::{expect_valid_tokenization, Symbol, Token}; + +// FUNCTION VALID TOKENIZATION +// ================================================================================================ + +#[test] +fn fn_with_scalars() { + let source = "fn fn_name(a: felt, b: felt) -> felt: + return a + b"; + + let tokens = [ + Token::Fn, + Token::FunctionIdent(Symbol::intern("fn_name")), + Token::LParen, + Token::Ident(Symbol::intern("a")), + Token::Colon, + Token::Felt, + Token::Comma, + Token::Ident(Symbol::intern("b")), + Token::Colon, + Token::Felt, + Token::RParen, + Token::Arrow, + Token::Felt, + Token::Colon, + Token::Return, + Token::Ident(Symbol::intern("a")), + Token::Plus, + Token::Ident(Symbol::intern("b")), + ]; + + expect_valid_tokenization(source, tokens.to_vec()); +} + +#[test] +fn fn_with_vectors() { + let source = "fn fn_name(a: felt[12], b: felt[12]) -> felt[12]: + return [x + y for x, y in (a, b)]"; + + let tokens = [ + Token::Fn, + Token::FunctionIdent(Symbol::intern("fn_name")), + Token::LParen, + Token::Ident(Symbol::intern("a")), + Token::Colon, + Token::Felt, + Token::LBracket, + Token::Num(12), + Token::RBracket, + Token::Comma, + Token::Ident(Symbol::intern("b")), + Token::Colon, + Token::Felt, + Token::LBracket, + Token::Num(12), + Token::RBracket, + Token::RParen, + Token::Arrow, + Token::Felt, + Token::LBracket, + Token::Num(12), + Token::RBracket, + Token::Colon, + Token::Return, + Token::LBracket, + Token::Ident(Symbol::intern("x")), + Token::Plus, + Token::Ident(Symbol::intern("y")), + Token::For, + Token::Ident(Symbol::intern("x")), + Token::Comma, + Token::Ident(Symbol::intern("y")), + Token::In, + Token::LParen, + Token::Ident(Symbol::intern("a")), + Token::Comma, + Token::Ident(Symbol::intern("b")), + Token::RParen, + Token::RBracket, + ]; + + expect_valid_tokenization(source, tokens.to_vec()); +} diff --git a/parser/src/lexer/tests/mod.rs b/parser/src/lexer/tests/mod.rs index 0426e68a..a312c007 100644 --- a/parser/src/lexer/tests/mod.rs +++ b/parser/src/lexer/tests/mod.rs @@ -6,6 +6,7 @@ mod arithmetic_ops; mod boundary_constraints; mod constants; mod evaluator_functions; +mod functions; mod identifiers; mod list_comprehension; mod modules; diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index ac0736a8..c238fa9c 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -75,6 +75,7 @@ Declaration: Declaration = { PeriodicColumns => Declaration::PeriodicColumns(<>), RandomValues => Declaration::RandomValues(<>), EvaluatorFunction => Declaration::EvaluatorFunction(<>), + Function => Declaration::Function(<>), => Declaration::Trace(Span::new(span!(l, r), trace)), => Declaration::PublicInputs(<>), => Declaration::BoundaryConstraints(<>), @@ -256,6 +257,66 @@ EvaluatorSegmentBindings: (SourceSpan, Vec>) = { "[" "]" => (span!(l, r), vec![]), } +// FUNCTIONS +// ================================================================================================ + +Function: Function = { + "fn" "(" ")" "->" + ":" + => Function::new(span!(l, r), name, params, return_types, body) +} + +FunctionBindings: Vec<(Identifier, Type)> = { + > => params, +} + +FunctionBinding: (Identifier, Type) = { + ":" => (name, ty), +} + +FunctionReturnBindingTypes: Vec = { + > => types, +} + +// TODO: refactor +FunctionBindingType: Type = { + "felt" => Type::Felt, + "felt" => Type::Vector(size as usize), + "felt" "[" "," "]" => Type::Matrix(row_size as usize, col_size as usize), +} + +FunctionBody: FunctionBody = { + // TODO: remove return keyword + "return" =>? { + if let Some(let_stmts) = let_stmts { + if let_stmts.iter().any(|stmt| !matches!(stmt, Statement::Let(_))) { + diagnostics.diagnostic(Severity::Error) + .with_message("invalid function definition") + .with_primary_label(span!(l, r), "only let statements are allowed in function definitions") + .emit(); + return Err(ParseError::Failed.into()); + } + Ok(FunctionBody::new(span!(l, r), let_stmts, vec![return_value])) + } else { + Ok(FunctionBody::new(span!(l, r), vec![], vec![return_value])) + } + }, + "return" > =>? { + if let Some(let_stmts) = let_stmts { + if let_stmts.iter().any(|stmt| !matches!(stmt, Statement::Let(_))) { + diagnostics.diagnostic(Severity::Error) + .with_message("invalid function definition") + .with_primary_label(span!(l, r), "only let statements are allowed in function definitions") + .emit(); + return Err(ParseError::Failed.into()); + } + Ok(FunctionBody::new(span!(l, r), let_stmts, return_values)) + } else { + Ok(FunctionBody::new(span!(l, r), vec![], return_values)) + } + } +} + // BOUNDARY CONSTRAINTS // ================================================================================================ @@ -533,6 +594,15 @@ Matrix: Vec> = { Vector>, } +Tuple: Vec = { + "(" "," )*> ")" => { + let mut v = v; + v.insert(0, v2); + v.insert(0, v1); + v + } +}; + Size: u64 = { "[" "]" => <> } @@ -591,10 +661,13 @@ extern { "last" => Token::Last, "integrity_constraints" => Token::IntegrityConstraints, "ev" => Token::Ev, + "fn" => Token::Fn, "enf" => Token::Enf, + "return" => Token::Return, "match" => Token::Match, "case" => Token::Case, "when" => Token::When, + "felt" => Token::Felt, "'" => Token::Quote, "=" => Token::Equal, "+" => Token::Plus, @@ -613,5 +686,6 @@ extern { ")" => Token::RParen, "." => Token::Dot, ".." => Token::DotDot, + "->" => Token::Arrow, } } diff --git a/parser/src/parser/tests/functions.rs b/parser/src/parser/tests/functions.rs new file mode 100644 index 00000000..a681a2ac --- /dev/null +++ b/parser/src/parser/tests/functions.rs @@ -0,0 +1,59 @@ +use miden_diagnostics::SourceSpan; + +use crate::ast::*; + +use super::ParseTest; + +// FUNCTIONS +// ================================================================================================ + +#[test] +fn fn_with_scalars() { + let source = " + mod test + + fn fn_with_scalars(a: felt, b: felt) -> felt: + return a + b"; + + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + expected.functions.insert( + ident!(fn_with_scalars), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fn_with_scalars), + vec![(ident!(a), Type::Felt), (ident!(b), Type::Felt)], + vec![Type::Felt], + FunctionBody { + span: SourceSpan::UNKNOWN, + body: vec![], + return_values: vec![expr!(add!(access!(a), access!(b)))], + }, + ), + ); + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn fn_with_vectors() { + let source = " + mod test + + fn fn_with_vectors(a: felt[12], b: felt[12]) -> felt[12]: + return [x + y for (x, y) in (a, b)]"; + + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + expected.functions.insert( + ident!(fn_with_vectors), + Function::new( + SourceSpan::UNKNOWN, + function_ident!(fn_with_vectors), + vec![ + (ident!(a), Type::Vector(12)), + (ident!(b), Type::Vector(12)), + ], + vec![Type::Vector(12)], + FunctionBody { span: SourceSpan::UNKNOWN, body: vec![], return_values: vec![lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!(access!(x), access!(y))).into()] } + ), + ); + ParseTest::new().expect_module_ast(source, expected); +} diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index cd614fb9..4c54cda5 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -599,6 +599,7 @@ mod calls; mod constant_propagation; mod constants; mod evaluators; +mod functions; mod identifiers; mod inlining; mod integrity_constraints; diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs index 25e0f4ae..ee160453 100644 --- a/parser/src/transforms/inlining.rs +++ b/parser/src/transforms/inlining.rs @@ -19,12 +19,12 @@ use super::constant_propagation; /// * Monomorphizing and inlining evaluators/functions at their call sites /// * Unrolling constraint comprehensions into a sequence of scalar constraints /// * Unrolling list comprehensions into a tree of `let` statements which end in -/// a vector expression (the implicit result of the tree). Each iteration of the -/// unrolled comprehension is reified as a value and bound to a variable so that -/// other transformations may refer to it directly. +/// a vector expression (the implicit result of the tree). Each iteration of the +/// unrolled comprehension is reified as a value and bound to a variable so that +/// other transformations may refer to it directly. /// * Rewriting aliases of top-level declarations to refer to those declarations directly /// * Removing let-bound variables which are unused, which is also used to clean up -/// after the aliasing rewrite mentioned above. +/// after the aliasing rewrite mentioned above. /// /// The trickiest transformation comes with inlining the body of evaluators at their /// call sites, as evaluator parameter lists can arbitrarily destructure/regroup columns @@ -75,6 +75,8 @@ pub struct Inlining<'a> { imported: HashMap, /// All evaluator functions in the program evaluators: HashMap, + /// All functions in the program + functions: HashMap, /// A set of identifiers for which accesses should be rewritten. /// /// When an identifier is in this set, it means it is a local alias for a trace column, @@ -187,6 +189,7 @@ impl<'a> Inlining<'a> { let_bound: Default::default(), imported: Default::default(), evaluators: Default::default(), + functions: Default::default(), rewrites: Default::default(), in_comprehension_constraint: false, next_ident: 0, @@ -972,6 +975,14 @@ impl<'a> Inlining<'a> { Ok(evaluator.body) } + /// This function handles inlining evaluator function calls. + fn expand_function_callsite( + &mut self, + _call: Call, + ) -> Result, SemanticAnalysisError> { + todo!() + } + /// Populate the set of access rewrites, as well as the initial set of bindings to use when inlining an evaluator function. /// /// This is done by resolving the arguments provided by the call to the evaluator, with the parameter list of the evaluator itself.