Skip to content

Commit

Permalink
Pattern support (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
utkn authored Jan 24, 2025
1 parent 3928b65 commit 4fe3c9b
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 83 deletions.
127 changes: 47 additions & 80 deletions src/lean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod builtin;
pub mod indent;
mod syntax;
mod pattern;

use std::collections::{HashMap, HashSet};

Expand Down Expand Up @@ -1030,7 +1031,6 @@ impl LeanEmitter {
_ => panic!("member access lhs is not a struct or tuple: {lhs_expr_ty:?}"),
}
}

HirExpression::Cast(cast) => {
let source = self.emit_expr(ind, cast.lhs)?;
let target_type = self.emit_fully_qualified_type(&cast.r#type);
Expand Down Expand Up @@ -1066,50 +1066,62 @@ impl LeanEmitter {
HirExpression::Lambda(lambda) => {
let ret_type = self.emit_fully_qualified_type(&lambda.return_type);

let arg_strs: Vec<String> = lambda
// Divide the parameters into simple and complex parameters, where simple parameters are parameters that can be expressed as a single let or let mut binding.
let (simple_params, complex_params): (Vec<_>, Vec<_>) = lambda
.parameters
.iter()
.map(|(pattern, ty)| {
let pattern_str = self.emit_pattern(pattern)?;
let typ = self.emit_fully_qualified_type(ty);
Ok(format!("{pattern_str} : {typ}"))
})
.try_collect()?;
let args = arg_strs.join(", ");
let captures = lambda
.captures
.iter()
.map(|capture| {
let capture_type =
self.context.def_interner.definition_type(capture.ident.id);
let capture_type = self.emit_fully_qualified_type(&capture_type);
let name = self.context.def_interner.definition_name(capture.ident.id);

format!("{name} : {capture_type}")
.enumerate()
.partition_map(|(param_idx, (pat, param_typ))| {
let rhs = self.emit_fully_qualified_type(param_typ);
if let Some((_, lhs)) =
pattern::try_format_simple_pattern(pat, "", self)
{
// If the parameter is simple, we can directly use the ident as the lhs.
let lhs = self.context.def_interner.definition_name(lhs.id);
itertools::Either::Left(format!("{lhs} : {rhs}"))
} else {
// If the parameter is complex, we need to generate a fresh binding for it.
let lhs = format!("param#{param_idx}");
itertools::Either::Right((pat.clone(), lhs, rhs))
}
});
// Convert the parameters into strings.
let params_str = complex_params.iter()
.map(|(_, lhs, rhs)| {
format!("{lhs} : {rhs}")
})
.chain(simple_params)
.join(", ");

// Convert the complex parameters into a series of let (mut) bindings.
let pattern_stmts_str = complex_params.iter().map(|(pat, lhs, _)| {
pattern::format_pattern(pat, lhs, self).join(";\n")
}).join(";\n");
let body = self.emit_expr(ind, lambda.body)?;
// Prepend the body with the appropriate block of let (mut) bindings if there are any complex parameters.
let body = if pattern_stmts_str.is_empty() {
body
} else {
ind.run(format!("{{\n{pattern_stmts_str};\n{{\n{body}\n}}\n}}"))
};

syntax::expr::format_lambda(&captures, &args, &body, &ret_type)
syntax::expr::format_lambda(&params_str, &body, &ret_type)
}
HirExpression::MethodCall(_) => {
HirExpression::MethodCall(..) => {
panic!("Method call expressions should not exist after type checking")
}
HirExpression::Comptime(_) => {
HirExpression::Comptime(..) => {
panic!("Comptime expressions should not exist after compilation is done")
}
HirExpression::Quote(_) => {
HirExpression::Quote(..) => {
panic!("Quote expressions should not exist after macro resolution")
}
HirExpression::Unquote(_) => {
HirExpression::Unquote(..) => {
panic!("Unquote expressions should not exist after macro resolution")
}

HirExpression::Error => {
panic!("Encountered error expression where none should exist")
}
HirExpression::Unsafe(_) => panic!("unsafe expressions not supported yet"),
HirExpression::Unsafe(..) => panic!("Unsafe expressions not supported yet"),
};

Ok(expression)
Expand Down Expand Up @@ -1157,20 +1169,18 @@ impl LeanEmitter {
let result = match stmt_data {
HirStatement::Expression(expr) => self.emit_expr(ind, expr)?,
HirStatement::Let(lets) => {
let binding_type = self.emit_fully_qualified_type(&lets.r#type);
let bound_expr = self.emit_expr(ind, lets.expression)?;
let name = self.emit_pattern(&lets.pattern)?;
// [TODO] proper pattern support
syntax::stmt::format_let_in(&name, &binding_type, &bound_expr)
if let Some((simple_stmt, _)) = pattern::try_format_simple_pattern(&lets.pattern, &bound_expr, self) {
simple_stmt
} else {
let pat_rhs = "param#0";
let mut stmts = vec![syntax::stmt::format_let_in(pat_rhs, &bound_expr)];
stmts.extend(pattern::format_pattern(&lets.pattern, pat_rhs, self));
stmts.join(";\n")
}
}
HirStatement::Constrain(constraint) => {
let constraint_expr = self.emit_expr(ind, constraint.0)?;
// [TODO] what to do with asserts with prints?
let _print_expr = if let Some(expr) = constraint.2 {
Some(self.emit_expr(ind, expr)?)
} else {
None
};

syntax::expr::format_builtin_call(
builtin::ASSERT_BUILTIN_NAME.into(),
Expand Down Expand Up @@ -1271,49 +1281,6 @@ impl LeanEmitter {
Ok(result)
}

/// Emits the Lean code corresponding to a Noir pattern.
///
/// # Errors
///
/// - [`Error`] if the extraction process fails for any reason.
pub fn emit_pattern(&self, pattern: &HirPattern) -> Result<String> {
let result = match pattern {
HirPattern::Identifier(id) => {
self.context.def_interner.definition_name(id.id).to_string()
}
HirPattern::Mutable(pattern, _) => {
let child_pattern = self.emit_pattern(pattern.as_ref())?;
format!("mut {child_pattern}")
}
HirPattern::Tuple(patterns, _) => {
let pattern_strs: Vec<String> = patterns
.iter()
.map(|pattern| self.emit_pattern(pattern))
.try_collect()?;
let patterns_str = pattern_strs.join(", ");

format!("({patterns_str})")
}
HirPattern::Struct(struct_ty, patterns, _) => {
let ty_str = self.emit_fully_qualified_type(struct_ty);

let pattern_strs: Vec<String> = patterns
.iter()
.map(|(pat_name, pat_expr)| {
let child_pat = self.emit_pattern(pat_expr)?;

Ok(format!("{pat_name}: {child_pat}"))
})
.try_collect()?;
let patterns_str = pattern_strs.join(", ");

format!("{ty_str} {{{patterns_str}}}")
}
};

Ok(result)
}

/// Emits the Lean source code corresponding to a Noir literal.
///
/// # Errors
Expand Down
116 changes: 116 additions & 0 deletions src/lean/pattern.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use noirc_frontend::{ast::Ident, hir_def::{expr::HirIdent, stmt::HirPattern}, Type};

#[derive(Clone, Debug)]
enum PatType {
Tuple(usize),
Struct { struct_type: Type, field: Ident },
}

/// The context of a pattern which contains all the necessary information to convert a nested `HirPattern::Identifier` pattern into a let (mut) binding.
/// This means, it contains the stack of tuple and struct patterns that the identifier is nested in, and whether the identifier is mutable.
#[derive(Clone, Debug)]
struct PatCtx {
stack: Vec<PatType>,
is_mut: bool,
}

impl Default for PatCtx {
fn default() -> Self {
Self { stack: Vec::new(), is_mut: false }
}
}

impl PatCtx {
fn push(&mut self, pat_type: PatType) {
self.stack.push(pat_type);
}

fn pop(&mut self) {
self.stack.pop();
}

fn set_mut(&mut self, is_mut: bool) {
self.is_mut = is_mut;
}
}

#[derive(Clone, Debug)]
struct PatRes(/** lhs **/ HirIdent, /** rhs **/ PatCtx);

fn parse_pattern(pat: &HirPattern, ctx: &mut PatCtx) -> Vec<PatRes> {
match pat {
HirPattern::Identifier(hir_ident) => {
vec![PatRes(hir_ident.clone(), ctx.clone())]
},
// A `mut` pattern makes the whole sub-pattern mutable.
// Note that nested mut patterns are unnecessary, and they are forbidden by the compiler.
HirPattern::Mutable(sub_pat, ..) => {
ctx.set_mut(true);
let res = parse_pattern(sub_pat, ctx);
ctx.set_mut(false);
res
},
HirPattern::Tuple(sub_pats, ..) => {
let mut res = Vec::new();
for (i, pat) in sub_pats.iter().enumerate() {
ctx.push(PatType::Tuple(i));
res.extend(parse_pattern(pat, ctx));
ctx.pop();
}
res
},
HirPattern::Struct(struct_type, sub_pats, ..) => {
let mut res = Vec::new();
for (ident, pat) in sub_pats.iter() {
ctx.push(PatType::Struct { struct_type: struct_type.clone(), field: ident.clone() });
res.extend(parse_pattern(pat, ctx));
ctx.pop();
}
res
},
}
}

/// Emits the Lean code corresponding to a Noir pattern as a single `let` or `let mut` binding, along with the `HirIdent` at the lhs of the pattern.
/// Returns `None` if the pattern is not simple enough to be expressed as a single binding.
pub(super) fn try_format_simple_pattern(pat: &HirPattern, pat_rhs: &str, emitter: &super::LeanEmitter) -> Option<(String, HirIdent)> {
match pat {
HirPattern::Identifier(ident) => {
format_pattern(pat, pat_rhs, emitter).pop().map(|pat| (pat, ident.clone()))
}
HirPattern::Mutable(sub_pat, ..) => {
if let HirPattern::Identifier(ident) = sub_pat.as_ref() {
format_pattern(pat, pat_rhs, emitter).pop().map(|pat| (pat, ident.clone()))
} else {
None
}
}
_ => None
}
}

/// Emits the Lean code corresponding to a Noir pattern as a series of `let` or `let mut` bindings.
pub(super) fn format_pattern(pat: &HirPattern, pat_rhs: &str, emitter: &super::LeanEmitter) -> Vec<String> {
let mut ctx = PatCtx::default();
parse_pattern(pat, &mut ctx).into_iter().map(|pat_res| {
let PatRes(id, ctx) = pat_res;
let lhs = emitter.context.def_interner.definition_name(id.id).to_string();
let mut rhs = pat_rhs.to_string();
for pat_type in ctx.stack {
match pat_type {
PatType::Tuple(i) => {
rhs = super::syntax::expr::format_tuple_access(&rhs, &format!("{}", i));
},
PatType::Struct { struct_type, field } => {
let struct_name = emitter.emit_fully_qualified_type(&struct_type);
rhs = super::syntax::expr::format_member_access(&struct_name, &rhs, &field.to_string());
},
}
}
if ctx.is_mut {
super::syntax::stmt::format_let_mut_in(&lhs, &rhs)
} else {
super::syntax::stmt::format_let_in(&lhs, &rhs)
}
}).collect()
}
12 changes: 9 additions & 3 deletions src/lean/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ pub(super) mod expr {
}

#[inline]
pub fn format_lambda(_captures: &str, args: &str, body: &str, ret_type: &str) -> String {
pub fn format_lambda(args: &str, body: &str, ret_type: &str) -> String {
format!("|{args}| -> {ret_type} {body}")
}
}
Expand All @@ -301,10 +301,16 @@ pub(super) mod stmt {
use super::*;

#[inline]
pub fn format_let_in(pat: &str, _binding_type: &str, bound_expr: &str) -> String {
format!("let {pat} = {bound_expr}")
pub fn format_let_in(lhs: &str, rhs: &str) -> String {
format!("let {lhs} = {rhs}")
}

#[inline]
pub fn format_let_mut_in(lhs: &str, rhs: &str) -> String {
format!("let mut {lhs} = {rhs}")
}


#[inline]
pub fn format_for_loop(loop_var: &str, loop_start: &str, loop_end: &str, body: &str) -> String {
formatdoc! {
Expand Down
9 changes: 9 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ mod test {
x
}
fn pattern_test() {
let opt = Option2::some(true);
let t = (1, opt, 3);
let (x, mut Option2 { _is_some, _value }, mut z) = t;
let lam = |(x, mut y, z) : (bool, bool, bool), k : Field| -> bool {
x
};
}
"#;

let source = Source::new(file_name, source);
Expand Down

0 comments on commit 4fe3c9b

Please sign in to comment.