Skip to content

Commit

Permalink
feat: add support for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tohrnii committed Jul 27, 2023
1 parent 5f7411b commit 058d56f
Show file tree
Hide file tree
Showing 11 changed files with 367 additions and 5 deletions.
66 changes: 66 additions & 0 deletions parser/src/ast/declarations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub enum Declaration {
///
/// Evaluator functions can be defined in any module of the program
EvaluatorFunction(EvaluatorFunction),
/// A pure function definition
///
/// Pure functions can be defined in any module of the program
Function(Function),
/// A `periodic_columns` section declaration
///
/// This may appear any number of times in the program, and may be declared in any module.
Expand Down Expand Up @@ -523,3 +527,65 @@ impl PartialEq for EvaluatorFunction {
self.name == other.name && self.params == other.params && self.body == other.body
}
}

/// Functions take a group of expressions as parameters and returns a group of expressions. These
/// values can be a Felt, a Vector or a Matrix. Functions do not take trace bindings as parameters.
#[derive(Debug, Clone, Spanned)]
pub struct Function {
#[span]
pub span: SourceSpan,
pub name: Identifier,
pub params: Vec<(Identifier, Type)>,
pub return_types: Vec<Type>,
pub body: FunctionBody,
}
impl Function {
/// Creates a new function.
pub const fn new(
span: SourceSpan,
name: Identifier,
params: Vec<(Identifier, Type)>,
return_types: Vec<Type>,
body: FunctionBody,
) -> Self {
Self {
span,
name,
params,
return_types,
body,
}
}
}
impl Eq for Function {}
impl PartialEq for Function {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
&& self.params == other.params
&& self.return_types == other.return_types
&& self.body == other.body
}
}

#[derive(Debug, Clone, Spanned)]
pub struct FunctionBody {
#[span]
pub span: SourceSpan,
pub body: Vec<Statement>,
pub return_values: Vec<Expr>,
}
impl FunctionBody {
pub const fn new(span: SourceSpan, body: Vec<Statement>, return_values: Vec<Expr>) -> Self {
Self {
span,
body,
return_values,
}
}
}
impl Eq for FunctionBody {}
impl PartialEq for FunctionBody {
fn eq(&self, other: &Self) -> bool {
self.body == other.body && self.return_values == other.return_values
}
}
6 changes: 6 additions & 0 deletions parser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ pub struct Program {
pub constants: BTreeMap<QualifiedIdentifier, Constant>,
/// The set of used evaluator functions referenced in this program.
pub evaluators: BTreeMap<QualifiedIdentifier, EvaluatorFunction>,
/// The set of used pure functions referenced in this program.
pub functions: BTreeMap<QualifiedIdentifier, Function>,
/// The set of used periodic columns referenced in this program.
pub periodic_columns: BTreeMap<QualifiedIdentifier, PeriodicColumn>,
/// The set of public inputs defined in the root module
Expand Down Expand Up @@ -115,6 +117,7 @@ impl Program {
name,
constants: Default::default(),
evaluators: Default::default(),
functions: Default::default(),
periodic_columns: Default::default(),
public_inputs: Default::default(),
random_values: None,
Expand Down Expand Up @@ -288,6 +291,7 @@ impl PartialEq for Program {
self.name == other.name
&& self.constants == other.constants
&& self.evaluators == other.evaluators
&& self.functions == other.functions
&& self.periodic_columns == other.periodic_columns
&& self.public_inputs == other.public_inputs
&& self.random_values == other.random_values
Expand Down Expand Up @@ -377,6 +381,8 @@ impl fmt::Display for Program {
f.write_str("\n")?;
}

// TODO: functions

Ok(())
}
}
Expand Down
22 changes: 22 additions & 0 deletions parser/src/ast/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub struct Module {
pub imports: BTreeMap<ModuleId, Import>,
pub constants: BTreeMap<Identifier, Constant>,
pub evaluators: BTreeMap<Identifier, EvaluatorFunction>,
pub functions: BTreeMap<Identifier, Function>,
pub periodic_columns: BTreeMap<Identifier, PeriodicColumn>,
pub public_inputs: BTreeMap<Identifier, PublicInput>,
pub random_values: Option<RandomValues>,
Expand All @@ -79,6 +80,7 @@ impl Module {
imports: Default::default(),
constants: Default::default(),
evaluators: Default::default(),
functions: Default::default(),
periodic_columns: Default::default(),
public_inputs: Default::default(),
random_values: None,
Expand Down Expand Up @@ -121,6 +123,9 @@ impl Module {
Declaration::EvaluatorFunction(evaluator) => {
module.declare_evaluator(diagnostics, &mut names, evaluator)?;
}
Declaration::Function(function) => {
module.declare_function(diagnostics, &mut names, function)?;
}
Declaration::PeriodicColumns(mut columns) => {
for column in columns.drain(..) {
module.declare_periodic_column(diagnostics, &mut names, column)?;
Expand Down Expand Up @@ -395,6 +400,22 @@ impl Module {
Ok(())
}

fn declare_function(
&mut self,
diagnostics: &DiagnosticsHandler,
names: &mut HashSet<NamespacedIdentifier>,
function: Function,
) -> Result<(), SemanticAnalysisError> {
if let Some(prev) = names.replace(NamespacedIdentifier::Function(function.name)) {
conflicting_declaration(diagnostics, "function", prev.span(), function.name.span());
return Err(SemanticAnalysisError::NameConflict(function.name.span()));
}

self.functions.insert(function.name, function);

Ok(())
}

fn declare_periodic_column(
&mut self,
diagnostics: &DiagnosticsHandler,
Expand Down Expand Up @@ -621,6 +642,7 @@ impl PartialEq for Module {
&& self.imports == other.imports
&& self.constants == other.constants
&& self.evaluators == other.evaluators
&& self.functions == other.functions
&& self.periodic_columns == other.periodic_columns
&& self.public_inputs == other.public_inputs
&& self.random_values == other.random_values
Expand Down
24 changes: 24 additions & 0 deletions parser/src/ast/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,30 @@ impl fmt::Display for Type {
}
}

// impl std::str::FromStr for Type {
// type Err = ();
// fn from_str(s: &str) -> Result<Self, Self::Err> {
// let s = s.trim();
// if s == "felt" {
// return Some(Type::Felt);
// }
// // Attempt to parse as Vector or Matrix
// if s.starts_with("felt[") && s.ends_with("]") {
// let contents = &s[5..s.len()-1]; // Extract inner content without brackets
// let numbers: Vec<usize> = contents.split(',')
// .map(str::trim)
// .filter_map(|num| num.parse().ok())
// .collect();
// match numbers.len() {
// 1 => return Some(Type::Vector(numbers[0])),
// 2 => return Some(Type::Matrix(numbers[0], numbers[1])),
// _ => return None,
// }
// }
// None
// }
// }

/// Represents the type signature of a function
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FunctionType {
Expand Down
17 changes: 16 additions & 1 deletion parser/src/lexer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ pub enum Token {
RandomValues,
/// Keyword to declare the evaluator function section in the AIR constraints module.
Ev,
/// Keyword to declare the function section in the AIR constraints module.
Fn,

// BOUNDARY CONSTRAINT KEYWORDS
// --------------------------------------------------------------------------------------------
Expand All @@ -137,9 +139,11 @@ pub enum Token {
// --------------------------------------------------------------------------------------------
/// Keyword to signify that a constraint needs to be enforced
Enf,
Return,
Match,
Case,
When,
Felt,

// PUNCTUATION
// --------------------------------------------------------------------------------------------
Expand All @@ -161,6 +165,7 @@ pub enum Token {
Ampersand,
Bar,
Bang,
Arrow,
}
impl Token {
pub fn from_keyword_or_ident(s: &str) -> Self {
Expand All @@ -177,13 +182,16 @@ impl Token {
"periodic_columns" => Self::PeriodicColumns,
"random_values" => Self::RandomValues,
"ev" => Self::Ev,
"fn" => Self::Fn,
"felt" => Self::Felt,
"boundary_constraints" => Self::BoundaryConstraints,
"integrity_constraints" => Self::IntegrityConstraints,
"first" => Self::First,
"last" => Self::Last,
"for" => Self::For,
"in" => Self::In,
"enf" => Self::Enf,
"return" => Self::Return,
"match" => Self::Match,
"case" => Self::Case,
"when" => Self::When,
Expand Down Expand Up @@ -247,13 +255,16 @@ impl fmt::Display for Token {
Self::PeriodicColumns => write!(f, "periodic_columns"),
Self::RandomValues => write!(f, "random_values"),
Self::Ev => write!(f, "ev"),
Self::Fn => write!(f, "fn"),
Self::Felt => write!(f, "felt"),
Self::BoundaryConstraints => write!(f, "boundary_constraints"),
Self::First => write!(f, "first"),
Self::Last => write!(f, "last"),
Self::IntegrityConstraints => write!(f, "integrity_constraints"),
Self::For => write!(f, "for"),
Self::In => write!(f, "in"),
Self::Enf => write!(f, "enf"),
Self::Return => write!(f, "return"),
Self::Match => write!(f, "match"),
Self::Case => write!(f, "case"),
Self::When => write!(f, "when"),
Expand All @@ -275,6 +286,7 @@ impl fmt::Display for Token {
Self::Ampersand => write!(f, "&"),
Self::Bar => write!(f, "|"),
Self::Bang => write!(f, "!"),
Self::Arrow => write!(f, "->"),
}
}
}
Expand Down Expand Up @@ -486,7 +498,10 @@ where
']' => pop!(self, Token::RBracket),
'=' => pop!(self, Token::Equal),
'+' => pop!(self, Token::Plus),
'-' => pop!(self, Token::Minus),
'-' => match self.peek() {
'>' => pop2!(self, Token::Arrow),
_ => pop!(self, Token::Minus),
},
'*' => pop!(self, Token::Star),
'^' => pop!(self, Token::Caret),
'&' => pop!(self, Token::Ampersand),
Expand Down
83 changes: 83 additions & 0 deletions parser/src/lexer/tests/functions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use super::{expect_valid_tokenization, Symbol, Token};

// FUNCTION VALID TOKENIZATION
// ================================================================================================

#[test]
fn fn_with_scalars() {
let source = "fn fn_name(a: felt, b: felt) -> felt:
return a + b";

let tokens = [
Token::Fn,
Token::FunctionIdent(Symbol::intern("fn_name")),
Token::LParen,
Token::Ident(Symbol::intern("a")),
Token::Colon,
Token::Felt,
Token::Comma,
Token::Ident(Symbol::intern("b")),
Token::Colon,
Token::Felt,
Token::RParen,
Token::Arrow,
Token::Felt,
Token::Colon,
Token::Return,
Token::Ident(Symbol::intern("a")),
Token::Plus,
Token::Ident(Symbol::intern("b")),
];

expect_valid_tokenization(source, tokens.to_vec());
}

#[test]
fn fn_with_vectors() {
let source = "fn fn_name(a: felt[12], b: felt[12]) -> felt[12]:
return [x + y for x, y in (a, b)]";

let tokens = [
Token::Fn,
Token::FunctionIdent(Symbol::intern("fn_name")),
Token::LParen,
Token::Ident(Symbol::intern("a")),
Token::Colon,
Token::Felt,
Token::LBracket,
Token::Num(12),
Token::RBracket,
Token::Comma,
Token::Ident(Symbol::intern("b")),
Token::Colon,
Token::Felt,
Token::LBracket,
Token::Num(12),
Token::RBracket,
Token::RParen,
Token::Arrow,
Token::Felt,
Token::LBracket,
Token::Num(12),
Token::RBracket,
Token::Colon,
Token::Return,
Token::LBracket,
Token::Ident(Symbol::intern("x")),
Token::Plus,
Token::Ident(Symbol::intern("y")),
Token::For,
Token::Ident(Symbol::intern("x")),
Token::Comma,
Token::Ident(Symbol::intern("y")),
Token::In,
Token::LParen,
Token::Ident(Symbol::intern("a")),
Token::Comma,
Token::Ident(Symbol::intern("b")),
Token::RParen,
Token::RBracket,
];

expect_valid_tokenization(source, tokens.to_vec());
}
1 change: 1 addition & 0 deletions parser/src/lexer/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod arithmetic_ops;
mod boundary_constraints;
mod constants;
mod evaluator_functions;
mod functions;
mod identifiers;
mod list_comprehension;
mod modules;
Expand Down
Loading

0 comments on commit 058d56f

Please sign in to comment.