Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pattern support #41

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 47 additions & 80 deletions src/lean/mod.rs
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
mod builtin;
pub mod indent;
mod syntax;
mod pattern;

use std::collections::{HashMap, HashSet};

@@ -1030,7 +1031,6 @@ impl LeanEmitter {
_ => panic!("member access lhs is not a struct or tuple: {lhs_expr_ty:?}"),
}
}

HirExpression::Cast(cast) => {
let source = self.emit_expr(ind, cast.lhs)?;
let target_type = self.emit_fully_qualified_type(&cast.r#type);
@@ -1066,50 +1066,62 @@ impl LeanEmitter {
HirExpression::Lambda(lambda) => {
let ret_type = self.emit_fully_qualified_type(&lambda.return_type);

let arg_strs: Vec<String> = lambda
// Divide the parameters into simple and complex parameters, where simple parameters are parameters that can be expressed as a single let or let mut binding.
let (simple_params, complex_params): (Vec<_>, Vec<_>) = lambda
.parameters
.iter()
.map(|(pattern, ty)| {
let pattern_str = self.emit_pattern(pattern)?;
let typ = self.emit_fully_qualified_type(ty);
Ok(format!("{pattern_str} : {typ}"))
})
.try_collect()?;
let args = arg_strs.join(", ");
let captures = lambda
.captures
.iter()
.map(|capture| {
let capture_type =
self.context.def_interner.definition_type(capture.ident.id);
let capture_type = self.emit_fully_qualified_type(&capture_type);
let name = self.context.def_interner.definition_name(capture.ident.id);

format!("{name} : {capture_type}")
.enumerate()
.partition_map(|(param_idx, (pat, param_typ))| {
let rhs = self.emit_fully_qualified_type(param_typ);
if let Some((_, lhs)) =
pattern::try_format_simple_pattern(pat, "", self)
{
// If the parameter is simple, we can directly use the ident as the lhs.
let lhs = self.context.def_interner.definition_name(lhs.id);
itertools::Either::Left(format!("{lhs} : {rhs}"))
} else {
// If the parameter is complex, we need to generate a fresh binding for it.
let lhs = format!("param#{param_idx}");
itertools::Either::Right((pat.clone(), lhs, rhs))
}
});
// Convert the parameters into strings.
let params_str = complex_params.iter()
.map(|(_, lhs, rhs)| {
format!("{lhs} : {rhs}")
})
.chain(simple_params)
.join(", ");

// Convert the complex parameters into a series of let (mut) bindings.
let pattern_stmts_str = complex_params.iter().map(|(pat, lhs, _)| {
pattern::format_pattern(pat, lhs, self).join(";\n")
}).join(";\n");
let body = self.emit_expr(ind, lambda.body)?;
// Prepend the body with the appropriate block of let (mut) bindings if there are any complex parameters.
let body = if pattern_stmts_str.is_empty() {
body
} else {
ind.run(format!("{{\n{pattern_stmts_str};\n{{\n{body}\n}}\n}}"))
};

syntax::expr::format_lambda(&captures, &args, &body, &ret_type)
syntax::expr::format_lambda(&params_str, &body, &ret_type)
}
HirExpression::MethodCall(_) => {
HirExpression::MethodCall(..) => {
panic!("Method call expressions should not exist after type checking")
}
HirExpression::Comptime(_) => {
HirExpression::Comptime(..) => {
panic!("Comptime expressions should not exist after compilation is done")
}
HirExpression::Quote(_) => {
HirExpression::Quote(..) => {
panic!("Quote expressions should not exist after macro resolution")
}
HirExpression::Unquote(_) => {
HirExpression::Unquote(..) => {
panic!("Unquote expressions should not exist after macro resolution")
}

HirExpression::Error => {
panic!("Encountered error expression where none should exist")
}
HirExpression::Unsafe(_) => panic!("unsafe expressions not supported yet"),
HirExpression::Unsafe(..) => panic!("Unsafe expressions not supported yet"),
};

Ok(expression)
@@ -1157,20 +1169,18 @@ impl LeanEmitter {
let result = match stmt_data {
HirStatement::Expression(expr) => self.emit_expr(ind, expr)?,
HirStatement::Let(lets) => {
let binding_type = self.emit_fully_qualified_type(&lets.r#type);
let bound_expr = self.emit_expr(ind, lets.expression)?;
let name = self.emit_pattern(&lets.pattern)?;
// [TODO] proper pattern support
syntax::stmt::format_let_in(&name, &binding_type, &bound_expr)
if let Some((simple_stmt, _)) = pattern::try_format_simple_pattern(&lets.pattern, &bound_expr, self) {
simple_stmt
} else {
let pat_rhs = "param#0";
let mut stmts = vec![syntax::stmt::format_let_in(pat_rhs, &bound_expr)];
stmts.extend(pattern::format_pattern(&lets.pattern, pat_rhs, self));
stmts.join(";\n")
}
}
HirStatement::Constrain(constraint) => {
let constraint_expr = self.emit_expr(ind, constraint.0)?;
// [TODO] what to do with asserts with prints?
let _print_expr = if let Some(expr) = constraint.2 {
Some(self.emit_expr(ind, expr)?)
} else {
None
};

syntax::expr::format_builtin_call(
builtin::ASSERT_BUILTIN_NAME.into(),
@@ -1271,49 +1281,6 @@ impl LeanEmitter {
Ok(result)
}

/// Emits the Lean code corresponding to a Noir pattern.
///
/// # Errors
///
/// - [`Error`] if the extraction process fails for any reason.
pub fn emit_pattern(&self, pattern: &HirPattern) -> Result<String> {
let result = match pattern {
HirPattern::Identifier(id) => {
self.context.def_interner.definition_name(id.id).to_string()
}
HirPattern::Mutable(pattern, _) => {
let child_pattern = self.emit_pattern(pattern.as_ref())?;
format!("mut {child_pattern}")
}
HirPattern::Tuple(patterns, _) => {
let pattern_strs: Vec<String> = patterns
.iter()
.map(|pattern| self.emit_pattern(pattern))
.try_collect()?;
let patterns_str = pattern_strs.join(", ");

format!("({patterns_str})")
}
HirPattern::Struct(struct_ty, patterns, _) => {
let ty_str = self.emit_fully_qualified_type(struct_ty);

let pattern_strs: Vec<String> = patterns
.iter()
.map(|(pat_name, pat_expr)| {
let child_pat = self.emit_pattern(pat_expr)?;

Ok(format!("{pat_name}: {child_pat}"))
})
.try_collect()?;
let patterns_str = pattern_strs.join(", ");

format!("{ty_str} {{{patterns_str}}}")
}
};

Ok(result)
}

/// Emits the Lean source code corresponding to a Noir literal.
///
/// # Errors
116 changes: 116 additions & 0 deletions src/lean/pattern.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use noirc_frontend::{ast::Ident, hir_def::{expr::HirIdent, stmt::HirPattern}, Type};

#[derive(Clone, Debug)]
enum PatType {
Tuple(usize),
Struct { struct_type: Type, field: Ident },
}

/// The context of a pattern which contains all the necessary information to convert a nested `HirPattern::Identifier` pattern into a let (mut) binding.
/// This means, it contains the stack of tuple and struct patterns that the identifier is nested in, and whether the identifier is mutable.
#[derive(Clone, Debug)]
struct PatCtx {
stack: Vec<PatType>,
is_mut: bool,
}

impl Default for PatCtx {
fn default() -> Self {
Self { stack: Vec::new(), is_mut: false }
}
}

impl PatCtx {
fn push(&mut self, pat_type: PatType) {
self.stack.push(pat_type);
}

fn pop(&mut self) {
self.stack.pop();
}

fn set_mut(&mut self, is_mut: bool) {
self.is_mut = is_mut;
}
}

#[derive(Clone, Debug)]
struct PatRes(/** lhs **/ HirIdent, /** rhs **/ PatCtx);

fn parse_pattern(pat: &HirPattern, ctx: &mut PatCtx) -> Vec<PatRes> {
match pat {
HirPattern::Identifier(hir_ident) => {
vec![PatRes(hir_ident.clone(), ctx.clone())]
},
// A `mut` pattern makes the whole sub-pattern mutable.
// Note that nested mut patterns are unnecessary, and they are forbidden by the compiler.
HirPattern::Mutable(sub_pat, ..) => {
ctx.set_mut(true);
let res = parse_pattern(sub_pat, ctx);
ctx.set_mut(false);
res
},
HirPattern::Tuple(sub_pats, ..) => {
let mut res = Vec::new();
for (i, pat) in sub_pats.iter().enumerate() {
ctx.push(PatType::Tuple(i));
res.extend(parse_pattern(pat, ctx));
ctx.pop();
}
res
},
HirPattern::Struct(struct_type, sub_pats, ..) => {
let mut res = Vec::new();
for (ident, pat) in sub_pats.iter() {
ctx.push(PatType::Struct { struct_type: struct_type.clone(), field: ident.clone() });
res.extend(parse_pattern(pat, ctx));
ctx.pop();
}
res
},
}
}

/// Emits the Lean code corresponding to a Noir pattern as a single `let` or `let mut` binding, along with the `HirIdent` at the lhs of the pattern.
/// Returns `None` if the pattern is not simple enough to be expressed as a single binding.
pub(super) fn try_format_simple_pattern(pat: &HirPattern, pat_rhs: &str, emitter: &super::LeanEmitter) -> Option<(String, HirIdent)> {
match pat {
HirPattern::Identifier(ident) => {
format_pattern(pat, pat_rhs, emitter).pop().map(|pat| (pat, ident.clone()))
}
HirPattern::Mutable(sub_pat, ..) => {
if let HirPattern::Identifier(ident) = sub_pat.as_ref() {
format_pattern(pat, pat_rhs, emitter).pop().map(|pat| (pat, ident.clone()))
} else {
None
}
}
_ => None
}
}

/// Emits the Lean code corresponding to a Noir pattern as a series of `let` or `let mut` bindings.
pub(super) fn format_pattern(pat: &HirPattern, pat_rhs: &str, emitter: &super::LeanEmitter) -> Vec<String> {
let mut ctx = PatCtx::default();
parse_pattern(pat, &mut ctx).into_iter().map(|pat_res| {
let PatRes(id, ctx) = pat_res;
let lhs = emitter.context.def_interner.definition_name(id.id).to_string();
let mut rhs = pat_rhs.to_string();
for pat_type in ctx.stack {
match pat_type {
PatType::Tuple(i) => {
rhs = super::syntax::expr::format_tuple_access(&rhs, &format!("{}", i));
},
PatType::Struct { struct_type, field } => {
let struct_name = emitter.emit_fully_qualified_type(&struct_type);
rhs = super::syntax::expr::format_member_access(&struct_name, &rhs, &field.to_string());
},
}
}
if ctx.is_mut {
super::syntax::stmt::format_let_mut_in(&lhs, &rhs)
} else {
super::syntax::stmt::format_let_in(&lhs, &rhs)
}
}).collect()
}
12 changes: 9 additions & 3 deletions src/lean/syntax.rs
Original file line number Diff line number Diff line change
@@ -291,7 +291,7 @@ pub(super) mod expr {
}

#[inline]
pub fn format_lambda(_captures: &str, args: &str, body: &str, ret_type: &str) -> String {
pub fn format_lambda(args: &str, body: &str, ret_type: &str) -> String {
format!("|{args}| -> {ret_type} {body}")
}
}
@@ -301,10 +301,16 @@ pub(super) mod stmt {
use super::*;

#[inline]
pub fn format_let_in(pat: &str, _binding_type: &str, bound_expr: &str) -> String {
format!("let {pat} = {bound_expr}")
pub fn format_let_in(lhs: &str, rhs: &str) -> String {
format!("let {lhs} = {rhs}")
}

#[inline]
pub fn format_let_mut_in(lhs: &str, rhs: &str) -> String {
format!("let mut {lhs} = {rhs}")
}


#[inline]
pub fn format_for_loop(loop_var: &str, loop_start: &str, loop_end: &str, body: &str) -> String {
formatdoc! {
9 changes: 9 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -203,6 +203,15 @@ mod test {
x
}
fn pattern_test() {
let opt = Option2::some(true);
let t = (1, opt, 3);
let (x, mut Option2 { _is_some, _value }, mut z) = t;
let lam = |(x, mut y, z) : (bool, bool, bool), k : Field| -> bool {
x
};
}
"#;

let source = Source::new(file_name, source);