Skip to content

Commit

Permalink
chore(experimental elaborator): Handle comptime expressions in the …
Browse files Browse the repository at this point in the history
…elaborator (noir-lang#5169)

# Description

## Problem\*

## Summary\*

Resolves some remaining `todo`s in the elaborator by calling the
interpreter on any comptime expressions within.

## Additional Context

The interpreter is mostly unchanged. The main exception being I've added
a new `Value::into_expression` function to convert into an `Expression`,
and renamed the old version to `Value::into_hir_expression` since it
creates an `ExprId` instead.

## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <[email protected]>
  • Loading branch information
jfecher and TomAFrench authored Jun 4, 2024
1 parent 888b94c commit 90632da
Show file tree
Hide file tree
Showing 14 changed files with 367 additions and 173 deletions.
19 changes: 18 additions & 1 deletion compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use crate::ast::{
Ident, ItemVisibility, Path, Pattern, Recoverable, Statement, StatementKind,
UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, Visibility,
};
use crate::macros_api::StructId;
use crate::node_interner::ExprId;
use crate::token::{Attributes, Token};
use acvm::{acir::AcirField, FieldElement};
use iter_extended::vecmap;
Expand Down Expand Up @@ -33,6 +35,11 @@ pub enum ExpressionKind {
Parenthesized(Box<Expression>),
Quote(BlockExpression),
Comptime(BlockExpression),

// This variant is only emitted when inlining the result of comptime
// code. It is used to translate function values back into the AST while
// guaranteeing they have the same instantiated type and definition id without resolving again.
Resolved(ExprId),
Error,
}

Expand Down Expand Up @@ -108,7 +115,11 @@ impl ExpressionKind {
}

pub fn constructor((type_name, fields): (Path, Vec<(Ident, Expression)>)) -> ExpressionKind {
ExpressionKind::Constructor(Box::new(ConstructorExpression { type_name, fields }))
ExpressionKind::Constructor(Box::new(ConstructorExpression {
type_name,
fields,
struct_type: None,
}))
}

/// Returns true if the expression is a literal integer
Expand Down Expand Up @@ -451,6 +462,11 @@ pub struct MethodCallExpression {
pub struct ConstructorExpression {
pub type_name: Path,
pub fields: Vec<(Ident, Expression)>,

/// This may be filled out during macro expansion
/// so that we can skip re-resolving the type name since it
/// would be lost at that point.
pub struct_type: Option<StructId>,
}

#[derive(Debug, PartialEq, Eq, Clone)]
Expand Down Expand Up @@ -522,6 +538,7 @@ impl Display for ExpressionKind {
Quote(block) => write!(f, "quote {block}"),
Comptime(block) => write!(f, "comptime {block}"),
Error => write!(f, "Error"),
Resolved(_) => write!(f, "?Resolved"),
}
}
}
Expand Down
96 changes: 67 additions & 29 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
UnresolvedTypeExpression,
},
hir::{
comptime::{self, Interpreter, InterpreterError},
resolution::{errors::ResolverError, resolver::LambdaContext},
type_check::TypeCheckError,
},
Expand Down Expand Up @@ -58,7 +59,10 @@ impl<'context> Elaborator<'context> {
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda),
ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr),
ExpressionKind::Quote(quote) => self.elaborate_quote(quote),
ExpressionKind::Comptime(comptime) => self.elaborate_comptime_block(comptime),
ExpressionKind::Comptime(comptime) => {
return self.elaborate_comptime_block(comptime, expr.span)
}
ExpressionKind::Resolved(id) => return (id, self.interner.id_type(id)),
ExpressionKind::Error => (HirExpression::Error, Type::Error),
};
let id = self.interner.push_expr(hir_expr);
Expand All @@ -68,6 +72,11 @@ impl<'context> Elaborator<'context> {
}

pub(super) fn elaborate_block(&mut self, block: BlockExpression) -> (HirExpression, Type) {
let (block, typ) = self.elaborate_block_expression(block);
(HirExpression::Block(block), typ)
}

fn elaborate_block_expression(&mut self, block: BlockExpression) -> (HirBlockExpression, Type) {
self.push_scope();
let mut block_type = Type::Unit;
let mut statements = Vec::with_capacity(block.statements.len());
Expand All @@ -92,7 +101,7 @@ impl<'context> Elaborator<'context> {
}

self.pop_scope();
(HirExpression::Block(HirBlockExpression { statements }), block_type)
(HirBlockExpression { statements }, block_type)
}

fn elaborate_literal(&mut self, literal: Literal, span: Span) -> (HirExpression, Type) {
Expand Down Expand Up @@ -365,32 +374,34 @@ impl<'context> Elaborator<'context> {
) -> (HirExpression, Type) {
let span = constructor.type_name.span();

match self.lookup_type_or_error(constructor.type_name) {
Some(Type::Struct(r#type, struct_generics)) => {
let struct_type = r#type.clone();
let generics = struct_generics.clone();

let fields = constructor.fields;
let field_types = r#type.borrow().get_fields(&struct_generics);
let fields = self.resolve_constructor_expr_fields(
struct_type.clone(),
field_types,
fields,
span,
);
let expr = HirExpression::Constructor(HirConstructorExpression {
fields,
r#type,
struct_generics,
});
(expr, Type::Struct(struct_type, generics))
}
Some(typ) => {
self.push_err(ResolverError::NonStructUsedInConstructor { typ, span });
(HirExpression::Error, Type::Error)
let (r#type, struct_generics) = if let Some(struct_id) = constructor.struct_type {
let typ = self.interner.get_struct(struct_id);
let generics = typ.borrow().instantiate(self.interner);
(typ, generics)
} else {
match self.lookup_type_or_error(constructor.type_name) {
Some(Type::Struct(r#type, struct_generics)) => (r#type, struct_generics),
Some(typ) => {
self.push_err(ResolverError::NonStructUsedInConstructor { typ, span });
return (HirExpression::Error, Type::Error);
}
None => return (HirExpression::Error, Type::Error),
}
None => (HirExpression::Error, Type::Error),
}
};

let struct_type = r#type.clone();
let generics = struct_generics.clone();

let fields = constructor.fields;
let field_types = r#type.borrow().get_fields(&struct_generics);
let fields =
self.resolve_constructor_expr_fields(struct_type.clone(), field_types, fields, span);
let expr = HirExpression::Constructor(HirConstructorExpression {
fields,
r#type,
struct_generics,
});
(expr, Type::Struct(struct_type, generics))
}

/// Resolve all the fields of a struct constructor expression.
Expand Down Expand Up @@ -620,7 +631,34 @@ impl<'context> Elaborator<'context> {
(HirExpression::Quote(block), Type::Code)
}

fn elaborate_comptime_block(&mut self, _comptime: BlockExpression) -> (HirExpression, Type) {
todo!("Elaborate comptime block")
fn elaborate_comptime_block(&mut self, block: BlockExpression, span: Span) -> (ExprId, Type) {
let (block, _typ) = self.elaborate_block_expression(block);
let mut interpreter = Interpreter::new(self.interner, &mut self.comptime_scopes);
let value = interpreter.evaluate_block(block);
self.inline_comptime_value(value, span)
}

pub(super) fn inline_comptime_value(
&mut self,
value: Result<comptime::Value, InterpreterError>,
span: Span,
) -> (ExprId, Type) {
let make_error = |this: &mut Self, error: InterpreterError| {
this.errors.push(error.into_compilation_error_pair());
let error = this.interner.push_expr(HirExpression::Error);
this.interner.push_expr_location(error, span, this.file);
(error, Type::Error)
};

let value = match value {
Ok(value) => value,
Err(error) => return make_error(self, error),
};

let location = Location::new(span, self.file);
match value.into_expression(self.interner, location) {
Ok(new_expr) => self.elaborate_expression(new_expr),
Err(error) => make_error(self, error),
}
}
}
120 changes: 26 additions & 94 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{
use crate::{
ast::{FunctionKind, UnresolvedTraitConstraint},
hir::{
comptime::{self, Interpreter},
def_collector::{
dc_crate::{
filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal,
Expand All @@ -21,7 +22,9 @@ use crate::{
macros_api::{
Ident, NodeInterner, NoirFunction, NoirStruct, Pattern, SecondaryAttribute, StructId,
},
node_interner::{DefinitionKind, DependencyId, ExprId, FuncId, TraitId, TypeAliasId},
node_interner::{
DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, TraitId, TypeAliasId,
},
Shared, Type, TypeVariable,
};
use crate::{
Expand Down Expand Up @@ -58,7 +61,7 @@ mod types;
use fm::FileId;
use iter_extended::vecmap;
use noirc_errors::{Location, Span};
use rustc_hash::FxHashSet as HashSet;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

/// ResolverMetas are tagged onto each definition to track how many times they are used
#[derive(Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -151,6 +154,11 @@ pub struct Elaborator<'context> {
local_module: LocalModuleId,

crate_id: CrateId,

/// Each value currently in scope in the comptime interpreter.
/// Each element of the Vec represents a scope with every scope together making
/// up all currently visible definitions. The first scope is always the global scope.
comptime_scopes: Vec<HashMap<DefinitionId, comptime::Value>>,
}

impl<'context> Elaborator<'context> {
Expand All @@ -177,6 +185,7 @@ impl<'context> Elaborator<'context> {
type_variables: Vec::new(),
trait_constraints: Vec::new(),
current_trait_impl: None,
comptime_scopes: vec![HashMap::default()],
}
}

Expand Down Expand Up @@ -685,98 +694,6 @@ impl<'context> Elaborator<'context> {
}
}

fn find_numeric_generics(
parameters: &Parameters,
return_type: &Type,
) -> Vec<(String, TypeVariable)> {
let mut found = BTreeMap::new();
for (_, parameter, _) in &parameters.0 {
Self::find_numeric_generics_in_type(parameter, &mut found);
}
Self::find_numeric_generics_in_type(return_type, &mut found);
found.into_iter().collect()
}

fn find_numeric_generics_in_type(typ: &Type, found: &mut BTreeMap<String, TypeVariable>) {
match typ {
Type::FieldElement
| Type::Integer(_, _)
| Type::Bool
| Type::Unit
| Type::Error
| Type::TypeVariable(_, _)
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::Code
| Type::Forall(_, _) => (),

Type::TraitAsType(_, _, args) => {
for arg in args {
Self::find_numeric_generics_in_type(arg, found);
}
}

Type::Array(length, element_type) => {
if let Type::NamedGeneric(type_variable, name) = length.as_ref() {
found.insert(name.to_string(), type_variable.clone());
}
Self::find_numeric_generics_in_type(element_type, found);
}

Type::Slice(element_type) => {
Self::find_numeric_generics_in_type(element_type, found);
}

Type::Tuple(fields) => {
for field in fields {
Self::find_numeric_generics_in_type(field, found);
}
}

Type::Function(parameters, return_type, _env) => {
for parameter in parameters {
Self::find_numeric_generics_in_type(parameter, found);
}
Self::find_numeric_generics_in_type(return_type, found);
}

Type::Struct(struct_type, generics) => {
for (i, generic) in generics.iter().enumerate() {
if let Type::NamedGeneric(type_variable, name) = generic {
if struct_type.borrow().generic_is_numeric(i) {
found.insert(name.to_string(), type_variable.clone());
}
} else {
Self::find_numeric_generics_in_type(generic, found);
}
}
}
Type::Alias(alias, generics) => {
for (i, generic) in generics.iter().enumerate() {
if let Type::NamedGeneric(type_variable, name) = generic {
if alias.borrow().generic_is_numeric(i) {
found.insert(name.to_string(), type_variable.clone());
}
} else {
Self::find_numeric_generics_in_type(generic, found);
}
}
}
Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found),
Type::String(length) => {
if let Type::NamedGeneric(type_variable, name) = length.as_ref() {
found.insert(name.to_string(), type_variable.clone());
}
}
Type::FmtString(length, fields) => {
if let Type::NamedGeneric(type_variable, name) = length.as_ref() {
found.insert(name.to_string(), type_variable.clone());
}
Self::find_numeric_generics_in_type(fields, found);
}
}
}

fn add_trait_constraints_to_scope(&mut self, func_meta: &FuncMeta) {
for constraint in &func_meta.trait_constraints {
let object = constraint.typ.clone();
Expand Down Expand Up @@ -1190,8 +1107,23 @@ impl<'context> Elaborator<'context> {
self.push_err(ResolverError::MutableGlobal { span });
}

let comptime = let_stmt.comptime;

self.elaborate_global_let(let_stmt, global_id);

if comptime {
let let_statement = self
.interner
.get_global_let_statement(global_id)
.expect("Let statement of global should be set by elaborate_global_let");

let mut interpreter = Interpreter::new(self.interner, &mut self.comptime_scopes);

if let Err(error) = interpreter.evaluate_let(let_statement) {
self.errors.push(error.into_compilation_error_pair());
}
}

// Avoid defaulting the types of globals here since they may be used in any function.
// Otherwise we may prematurely default to a Field inside the next function if this
// global was unused there, even if it is consistently used as a u8 everywhere else.
Expand Down
2 changes: 2 additions & 0 deletions compiler/noirc_frontend/src/elaborator/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@ impl<'context> Elaborator<'context> {

pub fn push_scope(&mut self) {
self.scopes.start_scope();
self.comptime_scopes.push(Default::default());
}

pub fn pop_scope(&mut self) {
let scope = self.scopes.end_scope();
self.comptime_scopes.pop();
self.check_for_unused_variables_in_scope_tree(scope.into());
}

Expand Down
Loading

0 comments on commit 90632da

Please sign in to comment.