From 90632dac6c493e4c7886503d71e682b611341e43 Mon Sep 17 00:00:00 2001 From: jfecher Date: Tue, 4 Jun 2024 16:47:05 +0100 Subject: [PATCH] chore(experimental elaborator): Handle `comptime` expressions in the elaborator (#5169) # Description ## Problem\* ## Summary\* Resolves some remaining `todo`s in the elaborator by calling the interpreter on any comptime expressions within. ## Additional Context The interpreter is mostly unchanged. The main exception being I've added a new `Value::into_expression` function to convert into an `Expression`, and renamed the old version to `Value::into_hir_expression` since it creates an `ExprId` instead. ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> --- compiler/noirc_frontend/src/ast/expression.rs | 19 ++- .../src/elaborator/expressions.rs | 96 ++++++++++---- compiler/noirc_frontend/src/elaborator/mod.rs | 120 ++++------------- .../noirc_frontend/src/elaborator/scope.rs | 2 + .../src/elaborator/statements.rs | 12 +- .../noirc_frontend/src/elaborator/types.rs | 96 +++++++++++++- .../noirc_frontend/src/hir/comptime/errors.rs | 6 + .../src/hir/comptime/interpreter.rs | 41 ++---- .../noirc_frontend/src/hir/comptime/scan.rs | 9 +- .../noirc_frontend/src/hir/comptime/tests.rs | 5 +- .../noirc_frontend/src/hir/comptime/value.rs | 125 ++++++++++++++++-- .../src/hir/def_collector/dc_crate.rs | 3 +- .../src/hir/resolution/resolver.rs | 3 + tooling/nargo_fmt/src/rewrite/expr.rs | 3 + 14 files changed, 367 insertions(+), 173 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 21131c71217..749e41d9c1c 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -5,6 +5,8 @@ use crate::ast::{ Ident, ItemVisibility, Path, Pattern, Recoverable, Statement, StatementKind, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, Visibility, }; +use crate::macros_api::StructId; +use crate::node_interner::ExprId; use crate::token::{Attributes, Token}; use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; @@ -33,6 +35,11 @@ pub enum ExpressionKind { Parenthesized(Box), Quote(BlockExpression), Comptime(BlockExpression), + + // This variant is only emitted when inlining the result of comptime + // code. It is used to translate function values back into the AST while + // guaranteeing they have the same instantiated type and definition id without resolving again. + Resolved(ExprId), Error, } @@ -108,7 +115,11 @@ impl ExpressionKind { } pub fn constructor((type_name, fields): (Path, Vec<(Ident, Expression)>)) -> ExpressionKind { - ExpressionKind::Constructor(Box::new(ConstructorExpression { type_name, fields })) + ExpressionKind::Constructor(Box::new(ConstructorExpression { + type_name, + fields, + struct_type: None, + })) } /// Returns true if the expression is a literal integer @@ -451,6 +462,11 @@ pub struct MethodCallExpression { pub struct ConstructorExpression { pub type_name: Path, pub fields: Vec<(Ident, Expression)>, + + /// This may be filled out during macro expansion + /// so that we can skip re-resolving the type name since it + /// would be lost at that point. + pub struct_type: Option, } #[derive(Debug, PartialEq, Eq, Clone)] @@ -522,6 +538,7 @@ impl Display for ExpressionKind { Quote(block) => write!(f, "quote {block}"), Comptime(block) => write!(f, "comptime {block}"), Error => write!(f, "Error"), + Resolved(_) => write!(f, "?Resolved"), } } } diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 8acd1867074..abd8781a213 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -9,6 +9,7 @@ use crate::{ UnresolvedTypeExpression, }, hir::{ + comptime::{self, Interpreter, InterpreterError}, resolution::{errors::ResolverError, resolver::LambdaContext}, type_check::TypeCheckError, }, @@ -58,7 +59,10 @@ impl<'context> Elaborator<'context> { ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda), ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr), ExpressionKind::Quote(quote) => self.elaborate_quote(quote), - ExpressionKind::Comptime(comptime) => self.elaborate_comptime_block(comptime), + ExpressionKind::Comptime(comptime) => { + return self.elaborate_comptime_block(comptime, expr.span) + } + ExpressionKind::Resolved(id) => return (id, self.interner.id_type(id)), ExpressionKind::Error => (HirExpression::Error, Type::Error), }; let id = self.interner.push_expr(hir_expr); @@ -68,6 +72,11 @@ impl<'context> Elaborator<'context> { } pub(super) fn elaborate_block(&mut self, block: BlockExpression) -> (HirExpression, Type) { + let (block, typ) = self.elaborate_block_expression(block); + (HirExpression::Block(block), typ) + } + + fn elaborate_block_expression(&mut self, block: BlockExpression) -> (HirBlockExpression, Type) { self.push_scope(); let mut block_type = Type::Unit; let mut statements = Vec::with_capacity(block.statements.len()); @@ -92,7 +101,7 @@ impl<'context> Elaborator<'context> { } self.pop_scope(); - (HirExpression::Block(HirBlockExpression { statements }), block_type) + (HirBlockExpression { statements }, block_type) } fn elaborate_literal(&mut self, literal: Literal, span: Span) -> (HirExpression, Type) { @@ -365,32 +374,34 @@ impl<'context> Elaborator<'context> { ) -> (HirExpression, Type) { let span = constructor.type_name.span(); - match self.lookup_type_or_error(constructor.type_name) { - Some(Type::Struct(r#type, struct_generics)) => { - let struct_type = r#type.clone(); - let generics = struct_generics.clone(); - - let fields = constructor.fields; - let field_types = r#type.borrow().get_fields(&struct_generics); - let fields = self.resolve_constructor_expr_fields( - struct_type.clone(), - field_types, - fields, - span, - ); - let expr = HirExpression::Constructor(HirConstructorExpression { - fields, - r#type, - struct_generics, - }); - (expr, Type::Struct(struct_type, generics)) - } - Some(typ) => { - self.push_err(ResolverError::NonStructUsedInConstructor { typ, span }); - (HirExpression::Error, Type::Error) + let (r#type, struct_generics) = if let Some(struct_id) = constructor.struct_type { + let typ = self.interner.get_struct(struct_id); + let generics = typ.borrow().instantiate(self.interner); + (typ, generics) + } else { + match self.lookup_type_or_error(constructor.type_name) { + Some(Type::Struct(r#type, struct_generics)) => (r#type, struct_generics), + Some(typ) => { + self.push_err(ResolverError::NonStructUsedInConstructor { typ, span }); + return (HirExpression::Error, Type::Error); + } + None => return (HirExpression::Error, Type::Error), } - None => (HirExpression::Error, Type::Error), - } + }; + + let struct_type = r#type.clone(); + let generics = struct_generics.clone(); + + let fields = constructor.fields; + let field_types = r#type.borrow().get_fields(&struct_generics); + let fields = + self.resolve_constructor_expr_fields(struct_type.clone(), field_types, fields, span); + let expr = HirExpression::Constructor(HirConstructorExpression { + fields, + r#type, + struct_generics, + }); + (expr, Type::Struct(struct_type, generics)) } /// Resolve all the fields of a struct constructor expression. @@ -620,7 +631,34 @@ impl<'context> Elaborator<'context> { (HirExpression::Quote(block), Type::Code) } - fn elaborate_comptime_block(&mut self, _comptime: BlockExpression) -> (HirExpression, Type) { - todo!("Elaborate comptime block") + fn elaborate_comptime_block(&mut self, block: BlockExpression, span: Span) -> (ExprId, Type) { + let (block, _typ) = self.elaborate_block_expression(block); + let mut interpreter = Interpreter::new(self.interner, &mut self.comptime_scopes); + let value = interpreter.evaluate_block(block); + self.inline_comptime_value(value, span) + } + + pub(super) fn inline_comptime_value( + &mut self, + value: Result, + span: Span, + ) -> (ExprId, Type) { + let make_error = |this: &mut Self, error: InterpreterError| { + this.errors.push(error.into_compilation_error_pair()); + let error = this.interner.push_expr(HirExpression::Error); + this.interner.push_expr_location(error, span, this.file); + (error, Type::Error) + }; + + let value = match value { + Ok(value) => value, + Err(error) => return make_error(self, error), + }; + + let location = Location::new(span, self.file); + match value.into_expression(self.interner, location) { + Ok(new_expr) => self.elaborate_expression(new_expr), + Err(error) => make_error(self, error), + } } } diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 24902c395b8..649a93b0bd6 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -6,6 +6,7 @@ use std::{ use crate::{ ast::{FunctionKind, UnresolvedTraitConstraint}, hir::{ + comptime::{self, Interpreter}, def_collector::{ dc_crate::{ filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal, @@ -21,7 +22,9 @@ use crate::{ macros_api::{ Ident, NodeInterner, NoirFunction, NoirStruct, Pattern, SecondaryAttribute, StructId, }, - node_interner::{DefinitionKind, DependencyId, ExprId, FuncId, TraitId, TypeAliasId}, + node_interner::{ + DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, TraitId, TypeAliasId, + }, Shared, Type, TypeVariable, }; use crate::{ @@ -58,7 +61,7 @@ mod types; use fm::FileId; use iter_extended::vecmap; use noirc_errors::{Location, Span}; -use rustc_hash::FxHashSet as HashSet; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; /// ResolverMetas are tagged onto each definition to track how many times they are used #[derive(Debug, PartialEq, Eq)] @@ -151,6 +154,11 @@ pub struct Elaborator<'context> { local_module: LocalModuleId, crate_id: CrateId, + + /// Each value currently in scope in the comptime interpreter. + /// Each element of the Vec represents a scope with every scope together making + /// up all currently visible definitions. The first scope is always the global scope. + comptime_scopes: Vec>, } impl<'context> Elaborator<'context> { @@ -177,6 +185,7 @@ impl<'context> Elaborator<'context> { type_variables: Vec::new(), trait_constraints: Vec::new(), current_trait_impl: None, + comptime_scopes: vec![HashMap::default()], } } @@ -685,98 +694,6 @@ impl<'context> Elaborator<'context> { } } - fn find_numeric_generics( - parameters: &Parameters, - return_type: &Type, - ) -> Vec<(String, TypeVariable)> { - let mut found = BTreeMap::new(); - for (_, parameter, _) in ¶meters.0 { - Self::find_numeric_generics_in_type(parameter, &mut found); - } - Self::find_numeric_generics_in_type(return_type, &mut found); - found.into_iter().collect() - } - - fn find_numeric_generics_in_type(typ: &Type, found: &mut BTreeMap) { - match typ { - Type::FieldElement - | Type::Integer(_, _) - | Type::Bool - | Type::Unit - | Type::Error - | Type::TypeVariable(_, _) - | Type::Constant(_) - | Type::NamedGeneric(_, _) - | Type::Code - | Type::Forall(_, _) => (), - - Type::TraitAsType(_, _, args) => { - for arg in args { - Self::find_numeric_generics_in_type(arg, found); - } - } - - Type::Array(length, element_type) => { - if let Type::NamedGeneric(type_variable, name) = length.as_ref() { - found.insert(name.to_string(), type_variable.clone()); - } - Self::find_numeric_generics_in_type(element_type, found); - } - - Type::Slice(element_type) => { - Self::find_numeric_generics_in_type(element_type, found); - } - - Type::Tuple(fields) => { - for field in fields { - Self::find_numeric_generics_in_type(field, found); - } - } - - Type::Function(parameters, return_type, _env) => { - for parameter in parameters { - Self::find_numeric_generics_in_type(parameter, found); - } - Self::find_numeric_generics_in_type(return_type, found); - } - - Type::Struct(struct_type, generics) => { - for (i, generic) in generics.iter().enumerate() { - if let Type::NamedGeneric(type_variable, name) = generic { - if struct_type.borrow().generic_is_numeric(i) { - found.insert(name.to_string(), type_variable.clone()); - } - } else { - Self::find_numeric_generics_in_type(generic, found); - } - } - } - Type::Alias(alias, generics) => { - for (i, generic) in generics.iter().enumerate() { - if let Type::NamedGeneric(type_variable, name) = generic { - if alias.borrow().generic_is_numeric(i) { - found.insert(name.to_string(), type_variable.clone()); - } - } else { - Self::find_numeric_generics_in_type(generic, found); - } - } - } - Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found), - Type::String(length) => { - if let Type::NamedGeneric(type_variable, name) = length.as_ref() { - found.insert(name.to_string(), type_variable.clone()); - } - } - Type::FmtString(length, fields) => { - if let Type::NamedGeneric(type_variable, name) = length.as_ref() { - found.insert(name.to_string(), type_variable.clone()); - } - Self::find_numeric_generics_in_type(fields, found); - } - } - } - fn add_trait_constraints_to_scope(&mut self, func_meta: &FuncMeta) { for constraint in &func_meta.trait_constraints { let object = constraint.typ.clone(); @@ -1190,8 +1107,23 @@ impl<'context> Elaborator<'context> { self.push_err(ResolverError::MutableGlobal { span }); } + let comptime = let_stmt.comptime; + self.elaborate_global_let(let_stmt, global_id); + if comptime { + let let_statement = self + .interner + .get_global_let_statement(global_id) + .expect("Let statement of global should be set by elaborate_global_let"); + + let mut interpreter = Interpreter::new(self.interner, &mut self.comptime_scopes); + + if let Err(error) = interpreter.evaluate_let(let_statement) { + self.errors.push(error.into_compilation_error_pair()); + } + } + // Avoid defaulting the types of globals here since they may be used in any function. // Otherwise we may prematurely default to a Field inside the next function if this // global was unused there, even if it is consistently used as a u8 everywhere else. diff --git a/compiler/noirc_frontend/src/elaborator/scope.rs b/compiler/noirc_frontend/src/elaborator/scope.rs index 6ae43bd3c49..9fd3be0a354 100644 --- a/compiler/noirc_frontend/src/elaborator/scope.rs +++ b/compiler/noirc_frontend/src/elaborator/scope.rs @@ -115,10 +115,12 @@ impl<'context> Elaborator<'context> { pub fn push_scope(&mut self) { self.scopes.start_scope(); + self.comptime_scopes.push(Default::default()); } pub fn pop_scope(&mut self) { let scope = self.scopes.end_scope(); + self.comptime_scopes.pop(); self.check_for_unused_variables_in_scope_tree(scope.into()); } diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index 8f2f4d3911a..5bcd43da6e5 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -3,6 +3,7 @@ use noirc_errors::{Location, Span}; use crate::{ ast::{AssignStatement, ConstrainStatement, LValue}, hir::{ + comptime::Interpreter, resolution::errors::ResolverError, type_check::{Source, TypeCheckError}, }, @@ -30,7 +31,7 @@ impl<'context> Elaborator<'context> { StatementKind::For(for_stmt) => self.elaborate_for(for_stmt), StatementKind::Break => self.elaborate_jump(true, statement.span), StatementKind::Continue => self.elaborate_jump(false, statement.span), - StatementKind::Comptime(statement) => self.elaborate_comptime(*statement), + StatementKind::Comptime(statement) => self.elaborate_comptime_statement(*statement), StatementKind::Expression(expr) => { let (expr, typ) = self.elaborate_expression(expr); (HirStatement::Expression(expr), typ) @@ -438,7 +439,12 @@ impl<'context> Elaborator<'context> { None } - pub(super) fn elaborate_comptime(&self, _statement: Statement) -> (HirStatement, Type) { - todo!("Comptime scanning") + fn elaborate_comptime_statement(&mut self, statement: Statement) -> (HirStatement, Type) { + let span = statement.span; + let (hir_statement, _typ) = self.elaborate_statement(statement); + let mut interpreter = Interpreter::new(self.interner, &mut self.comptime_scopes); + let value = interpreter.evaluate_statement(hir_statement); + let (expr, typ) = self.inline_comptime_value(value, span); + (HirStatement::Expression(expr), typ) } } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 3d677d8740d..d4cbdac3507 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::{collections::BTreeMap, rc::Rc}; use acvm::acir::AcirField; use iter_extended::vecmap; @@ -22,7 +22,7 @@ use crate::{ HirBinaryOp, HirCallExpression, HirIdent, HirMemberAccess, HirMethodReference, HirPrefixExpression, }, - function::FuncMeta, + function::{FuncMeta, Parameters}, traits::TraitConstraint, }, macros_api::{ @@ -1362,4 +1362,96 @@ impl<'context> Elaborator<'context> { self.generics.push((rc_name, typevar, span)); } } + + pub fn find_numeric_generics( + parameters: &Parameters, + return_type: &Type, + ) -> Vec<(String, TypeVariable)> { + let mut found = BTreeMap::new(); + for (_, parameter, _) in ¶meters.0 { + Self::find_numeric_generics_in_type(parameter, &mut found); + } + Self::find_numeric_generics_in_type(return_type, &mut found); + found.into_iter().collect() + } + + fn find_numeric_generics_in_type(typ: &Type, found: &mut BTreeMap) { + match typ { + Type::FieldElement + | Type::Integer(_, _) + | Type::Bool + | Type::Unit + | Type::Error + | Type::TypeVariable(_, _) + | Type::Constant(_) + | Type::NamedGeneric(_, _) + | Type::Code + | Type::Forall(_, _) => (), + + Type::TraitAsType(_, _, args) => { + for arg in args { + Self::find_numeric_generics_in_type(arg, found); + } + } + + Type::Array(length, element_type) => { + if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + found.insert(name.to_string(), type_variable.clone()); + } + Self::find_numeric_generics_in_type(element_type, found); + } + + Type::Slice(element_type) => { + Self::find_numeric_generics_in_type(element_type, found); + } + + Type::Tuple(fields) => { + for field in fields { + Self::find_numeric_generics_in_type(field, found); + } + } + + Type::Function(parameters, return_type, _env) => { + for parameter in parameters { + Self::find_numeric_generics_in_type(parameter, found); + } + Self::find_numeric_generics_in_type(return_type, found); + } + + Type::Struct(struct_type, generics) => { + for (i, generic) in generics.iter().enumerate() { + if let Type::NamedGeneric(type_variable, name) = generic { + if struct_type.borrow().generic_is_numeric(i) { + found.insert(name.to_string(), type_variable.clone()); + } + } else { + Self::find_numeric_generics_in_type(generic, found); + } + } + } + Type::Alias(alias, generics) => { + for (i, generic) in generics.iter().enumerate() { + if let Type::NamedGeneric(type_variable, name) = generic { + if alias.borrow().generic_is_numeric(i) { + found.insert(name.to_string(), type_variable.clone()); + } + } else { + Self::find_numeric_generics_in_type(generic, found); + } + } + } + Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found), + Type::String(length) => { + if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + found.insert(name.to_string(), type_variable.clone()); + } + } + Type::FmtString(length, fields) => { + if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + found.insert(name.to_string(), type_variable.clone()); + } + Self::find_numeric_generics_in_type(fields, found); + } + } + } } diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index 34cecf0ece4..20e3fd94b7d 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -52,6 +52,12 @@ pub enum InterpreterError { #[allow(unused)] pub(super) type IResult = std::result::Result; +impl From for CompilationError { + fn from(error: InterpreterError) -> Self { + CompilationError::InterpreterError(error) + } +} + impl InterpreterError { pub fn into_compilation_error_pair(self) -> (CompilationError, fm::FileId) { let location = self.get_location(); diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index c0aeb910f22..82e7d70141d 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -4,7 +4,7 @@ use acvm::{acir::AcirField, FieldElement}; use im::Vector; use iter_extended::try_vecmap; use noirc_errors::Location; -use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use rustc_hash::FxHashMap as HashMap; use crate::ast::{BinaryOpKind, FunctionKind, IntegerBitSize, Signedness}; use crate::{ @@ -36,34 +36,18 @@ pub struct Interpreter<'interner> { /// Each value currently in scope in the interpreter. /// Each element of the Vec represents a scope with every scope together making /// up all currently visible definitions. - scopes: Vec>, - - /// True if we've expanded any macros into any functions and will need - /// to redo name resolution & type checking for that function. - changed_functions: HashSet, - - /// True if we've expanded any macros into global scope and will need - /// to redo name resolution & type checking for everything. - changed_globally: bool, + scopes: &'interner mut Vec>, in_loop: bool, - - /// True if we're currently in a compile-time context. - /// If this is false code is skipped over instead of executed. - in_comptime_context: bool, } #[allow(unused)] impl<'a> Interpreter<'a> { - pub(crate) fn new(interner: &'a mut NodeInterner) -> Self { - Self { - interner, - scopes: vec![HashMap::default()], - changed_functions: HashSet::default(), - changed_globally: false, - in_loop: false, - in_comptime_context: false, - } + pub(crate) fn new( + interner: &'a mut NodeInterner, + scopes: &'a mut Vec>, + ) -> Self { + Self { interner, scopes, in_loop: false } } pub(crate) fn call_function( @@ -468,7 +452,7 @@ impl<'a> Interpreter<'a> { } } - pub(super) fn evaluate_block(&mut self, mut block: HirBlockExpression) -> IResult { + pub fn evaluate_block(&mut self, mut block: HirBlockExpression) -> IResult { let last_statement = block.statements.pop(); self.push_scope(); @@ -1077,7 +1061,7 @@ impl<'a> Interpreter<'a> { Ok(Value::Closure(lambda, environment, typ)) } - fn evaluate_statement(&mut self, statement: StmtId) -> IResult { + pub fn evaluate_statement(&mut self, statement: StmtId) -> IResult { match self.interner.statement(&statement) { HirStatement::Let(let_) => self.evaluate_let(let_), HirStatement::Constrain(constrain) => self.evaluate_constrain(constrain), @@ -1098,7 +1082,7 @@ impl<'a> Interpreter<'a> { } } - pub(super) fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult { + pub fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult { let rhs = self.evaluate(let_.expression)?; let location = self.interner.expr_location(&let_.expression); self.define_pattern(&let_.pattern, &let_.r#type, rhs, location)?; @@ -1265,9 +1249,6 @@ impl<'a> Interpreter<'a> { } pub(super) fn evaluate_comptime(&mut self, statement: StmtId) -> IResult { - let was_in_comptime = std::mem::replace(&mut self.in_comptime_context, true); - let result = self.evaluate_statement(statement); - self.in_comptime_context = was_in_comptime; - result + self.evaluate_statement(statement) } } diff --git a/compiler/noirc_frontend/src/hir/comptime/scan.rs b/compiler/noirc_frontend/src/hir/comptime/scan.rs index cc6b9aa7e9c..02010b6886d 100644 --- a/compiler/noirc_frontend/src/hir/comptime/scan.rs +++ b/compiler/noirc_frontend/src/hir/comptime/scan.rs @@ -82,7 +82,7 @@ impl<'interner> Interpreter<'interner> { HirExpression::Comptime(block) => { let location = self.interner.expr_location(&expr); let new_expr = - self.evaluate_block(block)?.into_expression(self.interner, location)?; + self.evaluate_block(block)?.into_hir_expression(self.interner, location)?; let new_expr = self.interner.expression(&new_expr); self.interner.replace_expr(&expr, new_expr); Ok(()) @@ -229,8 +229,9 @@ impl<'interner> Interpreter<'interner> { HirStatement::Error => Ok(()), HirStatement::Comptime(comptime) => { let location = self.interner.statement_location(comptime); - let new_expr = - self.evaluate_comptime(comptime)?.into_expression(self.interner, location)?; + let new_expr = self + .evaluate_comptime(comptime)? + .into_hir_expression(self.interner, location)?; self.interner.replace_statement(statement, HirStatement::Expression(new_expr)); Ok(()) } @@ -249,7 +250,7 @@ impl<'interner> Interpreter<'interner> { fn inline_expression(&mut self, value: Value, expr: ExprId) -> IResult<()> { let location = self.interner.expr_location(&expr); - let new_expr = value.into_expression(self.interner, location)?; + let new_expr = value.into_hir_expression(self.interner, location)?; let new_expr = self.interner.expression(&new_expr); self.interner.replace_expr(&expr, new_expr); Ok(()) diff --git a/compiler/noirc_frontend/src/hir/comptime/tests.rs b/compiler/noirc_frontend/src/hir/comptime/tests.rs index 41475d3ccf4..43f6e21905b 100644 --- a/compiler/noirc_frontend/src/hir/comptime/tests.rs +++ b/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -1,5 +1,7 @@ #![cfg(test)] +use std::collections::HashMap; + use noirc_errors::Location; use super::errors::InterpreterError; @@ -9,7 +11,8 @@ use crate::hir::type_check::test::type_check_src_code; fn interpret_helper(src: &str, func_namespace: Vec) -> Result { let (mut interner, main_id) = type_check_src_code(src, func_namespace); - let mut interpreter = Interpreter::new(&mut interner); + let mut scopes = vec![HashMap::default()]; + let mut interpreter = Interpreter::new(&mut interner, &mut scopes); let no_location = Location::dummy(); interpreter.call_function(main_id, Vec::new(), no_location) diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 3c8b6e92445..f2ff93b8929 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -6,9 +6,13 @@ use iter_extended::{try_vecmap, vecmap}; use noirc_errors::Location; use crate::{ - ast::{BlockExpression, Ident, IntegerBitSize, Signedness}, + ast::{ + ArrayLiteral, BlockExpression, ConstructorExpression, Ident, IntegerBitSize, Signedness, + }, hir_def::expr::{HirArrayLiteral, HirConstructorExpression, HirIdent, HirLambda, ImplKind}, - macros_api::{HirExpression, HirLiteral, NodeInterner}, + macros_api::{ + Expression, ExpressionKind, HirExpression, HirLiteral, Literal, NodeInterner, Path, + }, node_interner::{ExprId, FuncId}, Shared, Type, }; @@ -78,6 +82,108 @@ impl Value { self, interner: &mut NodeInterner, location: Location, + ) -> IResult { + let kind = match self { + Value::Unit => ExpressionKind::Literal(Literal::Unit), + Value::Bool(value) => ExpressionKind::Literal(Literal::Bool(value)), + Value::Field(value) => ExpressionKind::Literal(Literal::Integer(value, false)), + Value::I8(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + ExpressionKind::Literal(Literal::Integer(value, negative)) + } + Value::I16(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + ExpressionKind::Literal(Literal::Integer(value, negative)) + } + Value::I32(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + ExpressionKind::Literal(Literal::Integer(value, negative)) + } + Value::I64(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + ExpressionKind::Literal(Literal::Integer(value, negative)) + } + Value::U8(value) => { + ExpressionKind::Literal(Literal::Integer((value as u128).into(), false)) + } + Value::U16(value) => { + ExpressionKind::Literal(Literal::Integer((value as u128).into(), false)) + } + Value::U32(value) => { + ExpressionKind::Literal(Literal::Integer((value as u128).into(), false)) + } + Value::U64(value) => { + ExpressionKind::Literal(Literal::Integer((value as u128).into(), false)) + } + Value::String(value) => ExpressionKind::Literal(Literal::Str(unwrap_rc(value))), + Value::Function(id, typ) => { + let id = interner.function_definition_id(id); + let impl_kind = ImplKind::NotATraitMethod; + let ident = HirIdent { location, id, impl_kind }; + let expr_id = interner.push_expr(HirExpression::Ident(ident, None)); + interner.push_expr_location(expr_id, location.span, location.file); + interner.push_expr_type(expr_id, typ); + ExpressionKind::Resolved(expr_id) + } + Value::Closure(_lambda, _env, _typ) => { + // TODO: How should a closure's environment be inlined? + let item = "Returning closures from a comptime fn"; + return Err(InterpreterError::Unimplemented { item, location }); + } + Value::Tuple(fields) => { + let fields = try_vecmap(fields, |field| field.into_expression(interner, location))?; + ExpressionKind::Tuple(fields) + } + Value::Struct(fields, typ) => { + let fields = try_vecmap(fields, |(name, field)| { + let field = field.into_expression(interner, location)?; + Ok((Ident::new(unwrap_rc(name), location.span), field)) + })?; + + let struct_type = match typ.follow_bindings() { + Type::Struct(def, _) => Some(def.borrow().id), + _ => return Err(InterpreterError::NonStructInConstructor { typ, location }), + }; + + // Since we've provided the struct_type, the path should be ignored. + let type_name = Path::from_single(String::new(), location.span); + ExpressionKind::Constructor(Box::new(ConstructorExpression { + type_name, + fields, + struct_type, + })) + } + Value::Array(elements, _) => { + let elements = + try_vecmap(elements, |element| element.into_expression(interner, location))?; + ExpressionKind::Literal(Literal::Array(ArrayLiteral::Standard(elements))) + } + Value::Slice(elements, _) => { + let elements = + try_vecmap(elements, |element| element.into_expression(interner, location))?; + ExpressionKind::Literal(Literal::Slice(ArrayLiteral::Standard(elements))) + } + Value::Code(block) => ExpressionKind::Block(unwrap_rc(block)), + Value::Pointer(_) => { + return Err(InterpreterError::CannotInlineMacro { value: self, location }) + } + }; + + Ok(Expression::new(kind, location.span)) + } + + pub(crate) fn into_hir_expression( + self, + interner: &mut NodeInterner, + location: Location, ) -> IResult { let typ = self.get_type().into_owned(); @@ -133,12 +239,13 @@ impl Value { return Err(InterpreterError::Unimplemented { item, location }); } Value::Tuple(fields) => { - let fields = try_vecmap(fields, |field| field.into_expression(interner, location))?; + let fields = + try_vecmap(fields, |field| field.into_hir_expression(interner, location))?; HirExpression::Tuple(fields) } Value::Struct(fields, typ) => { let fields = try_vecmap(fields, |(name, field)| { - let field = field.into_expression(interner, location)?; + let field = field.into_hir_expression(interner, location)?; Ok((Ident::new(unwrap_rc(name), location.span), field)) })?; @@ -154,13 +261,15 @@ impl Value { }) } Value::Array(elements, _) => { - let elements = - try_vecmap(elements, |elements| elements.into_expression(interner, location))?; + let elements = try_vecmap(elements, |element| { + element.into_hir_expression(interner, location) + })?; HirExpression::Literal(HirLiteral::Array(HirArrayLiteral::Standard(elements))) } Value::Slice(elements, _) => { - let elements = - try_vecmap(elements, |elements| elements.into_expression(interner, location))?; + let elements = try_vecmap(elements, |element| { + element.into_hir_expression(interner, location) + })?; HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements))) } Value::Code(block) => HirExpression::Unquote(unwrap_rc(block)), diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 05147af5459..096bab2b47e 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -548,7 +548,8 @@ impl ResolvedModule { /// Evaluate all `comptime` expressions in this module fn evaluate_comptime(&mut self, interner: &mut NodeInterner) { if self.count_errors() == 0 { - let mut interpreter = Interpreter::new(interner); + let mut scopes = vec![HashMap::default()]; + let mut interpreter = Interpreter::new(interner, &mut scopes); for (_file, global) in &self.globals { if let Err(error) = interpreter.scan_global(*global) { diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 35ba964c499..8f5e99bacb9 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -1637,6 +1637,9 @@ impl<'a> Resolver<'a> { // The quoted expression isn't resolved since we don't want errors if variables aren't defined ExpressionKind::Quote(block) => HirExpression::Quote(block), ExpressionKind::Comptime(block) => HirExpression::Comptime(self.resolve_block(block)), + ExpressionKind::Resolved(_) => unreachable!( + "ExpressionKind::Resolved should only be emitted by the comptime interpreter" + ), }; // If these lines are ever changed, make sure to change the early return diff --git a/tooling/nargo_fmt/src/rewrite/expr.rs b/tooling/nargo_fmt/src/rewrite/expr.rs index e5b30f99b7b..9a704717ade 100644 --- a/tooling/nargo_fmt/src/rewrite/expr.rs +++ b/tooling/nargo_fmt/src/rewrite/expr.rs @@ -171,6 +171,9 @@ pub(crate) fn rewrite( format!("comptime {}", rewrite_block(visitor, block, span)) } ExpressionKind::Error => unreachable!(), + ExpressionKind::Resolved(_) => { + unreachable!("ExpressionKind::Resolved should only emitted by the comptime interpreter") + } } }