diff --git a/Cargo.lock b/Cargo.lock index f08467e0172..f71e0ceabbb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3200,6 +3200,7 @@ dependencies = [ "bn254_blackbox_solver", "cfg-if 1.0.0", "fm", + "fxhash", "im", "iter-extended", "noirc_arena", diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 9377cadb260..78724261f0d 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -157,7 +157,6 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Self { - RuntimeSeparatorContext::separate_runtime(&mut self); - self - } -} - -#[derive(Debug, Default)] -struct RuntimeSeparatorContext { - // Original functions to clone to brillig - acir_functions_called_from_brillig: BTreeSet, - // Tracks the original => cloned version - mapped_functions: HashMap, -} - -impl RuntimeSeparatorContext { - pub(crate) fn separate_runtime(ssa: &mut Ssa) { - let mut runtime_separator = RuntimeSeparatorContext::default(); - - // We first collect all the acir functions called from a brillig context by exploring the SSA recursively - let mut processed_functions = HashSet::default(); - runtime_separator.collect_acir_functions_called_from_brillig( - ssa, - ssa.main_id, - false, - &mut processed_functions, - ); - - // Now we clone the relevant acir functions and change their runtime to brillig - runtime_separator.convert_acir_functions_called_from_brillig_to_brillig(ssa); - - // Now we update any calls within a brillig context to the mapped functions - runtime_separator.replace_calls_to_mapped_functions(ssa); - - // Some functions might be unreachable now (for example an acir function only called from brillig) - prune_unreachable_functions(ssa); - } - - fn collect_acir_functions_called_from_brillig( - &mut self, - ssa: &Ssa, - current_func_id: FunctionId, - mut within_brillig: bool, - processed_functions: &mut HashSet<(/* within_brillig */ bool, FunctionId)>, - ) { - // Processed functions needs the within brillig flag, since it is possible to call the same function from both brillig and acir - if processed_functions.contains(&(within_brillig, current_func_id)) { - return; - } - processed_functions.insert((within_brillig, current_func_id)); - - let func = &ssa.functions[¤t_func_id]; - if matches!(func.runtime(), RuntimeType::Brillig(_)) { - within_brillig = true; - } - - let called_functions = called_functions(func); - - if within_brillig { - for called_func_id in called_functions.iter() { - let called_func = &ssa.functions[called_func_id]; - if matches!(called_func.runtime(), RuntimeType::Acir(_)) { - self.acir_functions_called_from_brillig.insert(*called_func_id); - } - } - } - - for called_func_id in called_functions.into_iter() { - self.collect_acir_functions_called_from_brillig( - ssa, - called_func_id, - within_brillig, - processed_functions, - ); - } - } - - fn convert_acir_functions_called_from_brillig_to_brillig(&mut self, ssa: &mut Ssa) { - for acir_func_id in self.acir_functions_called_from_brillig.iter() { - let RuntimeType::Acir(inline_type) = ssa.functions[acir_func_id].runtime() else { - unreachable!("Function to transform to brillig should be ACIR") - }; - let cloned_id = ssa.clone_fn(*acir_func_id); - let new_func = - ssa.functions.get_mut(&cloned_id).expect("Cloned function should exist in SSA"); - new_func.set_runtime(RuntimeType::Brillig(inline_type)); - self.mapped_functions.insert(*acir_func_id, cloned_id); - } - } - - fn replace_calls_to_mapped_functions(&self, ssa: &mut Ssa) { - for (_function_id, func) in ssa.functions.iter_mut() { - if matches!(func.runtime(), RuntimeType::Brillig(_)) { - for called_func_value_id in called_functions_values(func).iter() { - let Value::Function(called_func_id) = &func.dfg[*called_func_value_id] else { - unreachable!("Value should be a function") - }; - if let Some(mapped_func_id) = self.mapped_functions.get(called_func_id) { - let mapped_value_id = func.dfg.import_function(*mapped_func_id); - func.dfg.set_value_from_id(*called_func_value_id, mapped_value_id); - } - } - } - } - } -} - -// We only consider direct calls to functions since functions as values should have been resolved -fn called_functions_values(func: &Function) -> BTreeSet { - let mut called_function_ids = BTreeSet::default(); - for block_id in func.reachable_blocks() { - for instruction_id in func.dfg[block_id].instructions() { - let Instruction::Call { func: called_value_id, .. } = &func.dfg[*instruction_id] else { - continue; - }; - - if let Value::Function(_) = func.dfg[*called_value_id] { - called_function_ids.insert(*called_value_id); - } - } - } - - called_function_ids -} - -fn called_functions(func: &Function) -> BTreeSet { - called_functions_values(func) - .into_iter() - .map(|value_id| { - let Value::Function(func_id) = func.dfg[value_id] else { - unreachable!("Value should be a function") - }; - func_id - }) - .collect() -} - -fn collect_reachable_functions( - ssa: &Ssa, - current_func_id: FunctionId, - reachable_functions: &mut HashSet, -) { - if reachable_functions.contains(¤t_func_id) { - return; - } - reachable_functions.insert(current_func_id); - - let func = &ssa.functions[¤t_func_id]; - let called_functions = called_functions(func); - - for called_func_id in called_functions.iter() { - collect_reachable_functions(ssa, *called_func_id, reachable_functions); - } -} - -fn prune_unreachable_functions(ssa: &mut Ssa) { - let mut reachable_functions = HashSet::default(); - collect_reachable_functions(ssa, ssa.main_id, &mut reachable_functions); - - ssa.functions.retain(|id, _value| reachable_functions.contains(id)); -} - -#[cfg(test)] -mod test { - use std::collections::BTreeSet; - - use noirc_frontend::monomorphization::ast::InlineType; - - use crate::ssa::{ - function_builder::FunctionBuilder, - ir::{ - function::{Function, FunctionId, RuntimeType}, - map::Id, - types::Type, - }, - opt::runtime_separation::called_functions, - ssa_gen::Ssa, - }; - - #[test] - fn basic_runtime_separation() { - // brillig fn foo { - // b0(): - // v0 = call bar() - // return v0 - // } - // acir fn bar { - // b0(): - // return 72 - // } - let foo_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("foo".into(), foo_id); - builder.current_function.set_runtime(RuntimeType::Brillig(InlineType::default())); - - let bar_id = Id::test_new(1); - let bar = builder.import_function(bar_id); - let results = builder.insert_call(bar, Vec::new(), vec![Type::field()]).to_vec(); - builder.terminate_with_return(results); - - builder.new_function("bar".into(), bar_id, InlineType::default()); - let expected_return = 72u128; - let seventy_two = builder.field_constant(expected_return); - builder.terminate_with_return(vec![seventy_two]); - - let ssa = builder.finish(); - assert_eq!(ssa.functions.len(), 2); - - // Expected result - // brillig fn foo { - // b0(): - // v0 = call bar() - // return v0 - // } - // brillig fn bar { - // b0(): - // return 72 - // } - let separated = ssa.separate_runtime(); - - // The original bar function must have been pruned - assert_eq!(separated.functions.len(), 2); - - // All functions should be brillig now - for func in separated.functions.values() { - assert_eq!(func.runtime(), RuntimeType::Brillig(InlineType::default())); - } - } - - fn find_func_by_name<'ssa>( - ssa: &'ssa Ssa, - funcs: &BTreeSet, - name: &str, - ) -> &'ssa Function { - funcs - .iter() - .find_map(|id| { - let func = ssa.functions.get(id).unwrap(); - if func.name() == name { - Some(func) - } else { - None - } - }) - .unwrap() - } - - #[test] - fn same_function_shared_acir_brillig() { - // acir fn foo { - // b0(): - // v0 = call bar() - // v1 = call baz() - // return v0, v1 - // } - // brillig fn bar { - // b0(): - // v0 = call baz() - // return v0 - // } - // acir fn baz { - // b0(): - // return 72 - // } - let foo_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("foo".into(), foo_id); - - let bar_id = Id::test_new(1); - let baz_id = Id::test_new(2); - let bar = builder.import_function(bar_id); - let baz = builder.import_function(baz_id); - let v0 = builder.insert_call(bar, Vec::new(), vec![Type::field()]).to_vec(); - let v1 = builder.insert_call(baz, Vec::new(), vec![Type::field()]).to_vec(); - builder.terminate_with_return(vec![v0[0], v1[0]]); - - builder.new_brillig_function("bar".into(), bar_id, InlineType::default()); - let baz = builder.import_function(baz_id); - let v0 = builder.insert_call(baz, Vec::new(), vec![Type::field()]).to_vec(); - builder.terminate_with_return(v0); - - builder.new_function("baz".into(), baz_id, InlineType::default()); - let expected_return = 72u128; - let seventy_two = builder.field_constant(expected_return); - builder.terminate_with_return(vec![seventy_two]); - - let ssa = builder.finish(); - assert_eq!(ssa.functions.len(), 3); - - // Expected result - // acir fn foo { - // b0(): - // v0 = call bar() - // v1 = call baz() <- baz_acir - // return v0, v1 - // } - // brillig fn bar { - // b0(): - // v0 = call baz() <- baz_brillig - // return v0 - // } - // acir fn baz { - // b0(): - // return 72 - // } - // brillig fn baz { - // b0(): - // return 72 - // } - let separated = ssa.separate_runtime(); - - // The original baz function must have been duplicated - assert_eq!(separated.functions.len(), 4); - - let main_function = separated.functions.get(&separated.main_id).unwrap(); - assert_eq!(main_function.runtime(), RuntimeType::Acir(InlineType::Inline)); - - let main_calls = called_functions(main_function); - assert_eq!(main_calls.len(), 2); - - let bar = find_func_by_name(&separated, &main_calls, "bar"); - let baz_acir = find_func_by_name(&separated, &main_calls, "baz"); - - assert_eq!(baz_acir.runtime(), RuntimeType::Acir(InlineType::Inline)); - assert_eq!(bar.runtime(), RuntimeType::Brillig(InlineType::default())); - - let bar_calls = called_functions(bar); - assert_eq!(bar_calls.len(), 1); - - let baz_brillig = find_func_by_name(&separated, &bar_calls, "baz"); - assert_eq!(baz_brillig.runtime(), RuntimeType::Brillig(InlineType::default())); - } -} diff --git a/compiler/noirc_frontend/Cargo.toml b/compiler/noirc_frontend/Cargo.toml index 5f8f02689c8..041c1b1e015 100644 --- a/compiler/noirc_frontend/Cargo.toml +++ b/compiler/noirc_frontend/Cargo.toml @@ -31,6 +31,7 @@ petgraph = "0.6" rangemap = "1.4.0" strum.workspace = true strum_macros.workspace = true +fxhash.workspace = true [dev-dependencies] diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 8c07d71de21..d09257b0634 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -25,11 +25,12 @@ use crate::{ Kind, Type, TypeBinding, TypeBindings, }; use acvm::{acir::AcirField, FieldElement}; +use fxhash::FxHashMap as HashMap; use iter_extended::{btree_map, try_vecmap, vecmap}; use noirc_errors::Location; use noirc_printable_type::PrintableType; use std::{ - collections::{BTreeMap, HashMap, VecDeque}, + collections::{BTreeMap, VecDeque}, unreachable, }; @@ -56,14 +57,12 @@ struct LambdaContext { /// This struct holds the FIFO queue of functions to monomorphize, which is added to /// whenever a new (function, type) combination is encountered. struct Monomorphizer<'interner> { - /// Functions are keyed by their unique ID and expected type so that we can monomorphize - /// a new version of the function for each type. - /// We also key by any turbofish generics that are specified. - /// This is necessary for a case where we may have a trait generic that can be instantiated - /// outside of a function parameter or return value. + /// Functions are keyed by their unique ID, whether they're unconstrained, their expected type, + /// and any generics they have so that we can monomorphize a new version of the function for each type. /// - /// Using nested HashMaps here lets us avoid cloning HirTypes when calling .get() - functions: HashMap), FuncId>>, + /// Keying by any turbofish generics that are specified is necessary for a case where we may have a + /// trait generic that can be instantiated outside of a function parameter or return value. + functions: Functions, /// Unlike functions, locals are only keyed by their unique ID because they are never /// duplicated during monomorphization. Doing so would allow them to be used polymorphically @@ -72,8 +71,15 @@ struct Monomorphizer<'interner> { locals: HashMap, /// Queue of functions to monomorphize next each item in the queue is a tuple of: - /// (old_id, new_monomorphized_id, any type bindings to apply, the trait method if old_id is from a trait impl) - queue: VecDeque<(node_interner::FuncId, FuncId, TypeBindings, Option, Location)>, + /// (old_id, new_monomorphized_id, any type bindings to apply, the trait method if old_id is from a trait impl, is_unconstrained, location) + queue: VecDeque<( + node_interner::FuncId, + FuncId, + TypeBindings, + Option, + bool, + Location, + )>, /// When a function finishes being monomorphized, the monomorphized ast::Function is /// stored here along with its FuncId. @@ -92,8 +98,16 @@ struct Monomorphizer<'interner> { return_location: Option, debug_type_tracker: DebugTypeTracker, + + in_unconstrained_function: bool, } +/// Using nested HashMaps here lets us avoid cloning HirTypes when calling .get() +type Functions = HashMap< + (node_interner::FuncId, /*is_unconstrained:*/ bool), + HashMap, FuncId>>, +>; + type HirType = crate::Type; /// Starting from the given `main` function, monomorphize the entire program, @@ -125,10 +139,12 @@ pub fn monomorphize_debug( let function_sig = monomorphizer.compile_main(main)?; while !monomorphizer.queue.is_empty() { - let (next_fn_id, new_id, bindings, trait_method, location) = + let (next_fn_id, new_id, bindings, trait_method, is_unconstrained, location) = monomorphizer.queue.pop_front().unwrap(); monomorphizer.locals.clear(); + monomorphizer.in_unconstrained_function = is_unconstrained; + perform_instantiation_bindings(&bindings); let interner = &monomorphizer.interner; let impl_bindings = perform_impl_bindings(interner, trait_method, next_fn_id, location) @@ -172,8 +188,8 @@ pub fn monomorphize_debug( impl<'interner> Monomorphizer<'interner> { fn new(interner: &'interner mut NodeInterner, debug_type_tracker: DebugTypeTracker) -> Self { Monomorphizer { - functions: HashMap::new(), - locals: HashMap::new(), + functions: HashMap::default(), + locals: HashMap::default(), queue: VecDeque::new(), finished_functions: BTreeMap::new(), next_local_id: 0, @@ -183,6 +199,7 @@ impl<'interner> Monomorphizer<'interner> { is_range_loop: false, return_location: None, debug_type_tracker, + in_unconstrained_function: false, } } @@ -207,14 +224,18 @@ impl<'interner> Monomorphizer<'interner> { id: node_interner::FuncId, expr_id: node_interner::ExprId, typ: &HirType, - turbofish_generics: Vec, + turbofish_generics: &[HirType], trait_method: Option, ) -> Definition { let typ = typ.follow_bindings(); + let turbofish_generics = vecmap(turbofish_generics, |typ| typ.follow_bindings()); + let is_unconstrained = self.is_unconstrained(id); + match self .functions - .get(&id) - .and_then(|inner_map| inner_map.get(&(typ.clone(), turbofish_generics.clone()))) + .get(&(id, is_unconstrained)) + .and_then(|inner_map| inner_map.get(&typ)) + .and_then(|inner_map| inner_map.get(&turbofish_generics)) { Some(id) => Definition::Function(*id), None => { @@ -257,14 +278,21 @@ impl<'interner> Monomorphizer<'interner> { } /// Prerequisite: typ = typ.follow_bindings() + /// and: turbofish_generics = vecmap(turbofish_generics, Type::follow_bindings) fn define_function( &mut self, id: node_interner::FuncId, typ: HirType, turbofish_generics: Vec, + is_unconstrained: bool, new_id: FuncId, ) { - self.functions.entry(id).or_default().insert((typ, turbofish_generics), new_id); + self.functions + .entry((id, is_unconstrained)) + .or_default() + .entry(typ) + .or_default() + .insert(turbofish_generics, new_id); } fn compile_main( @@ -874,7 +902,7 @@ impl<'interner> Monomorphizer<'interner> { *func_id, expr_id, &typ, - generics.unwrap_or_default(), + &generics.unwrap_or_default(), None, ); let typ = Self::convert_type(&typ, ident.location)?; @@ -1281,7 +1309,7 @@ impl<'interner> Monomorphizer<'interner> { .map_err(MonomorphizationError::InterpreterError)?; let func_id = - match self.lookup_function(func_id, expr_id, &function_type, vec![], Some(method)) { + match self.lookup_function(func_id, expr_id, &function_type, &[], Some(method)) { Definition::Function(func_id) => func_id, _ => unreachable!(), }; @@ -1546,12 +1574,19 @@ impl<'interner> Monomorphizer<'interner> { trait_method: Option, ) -> FuncId { let new_id = self.next_function_id(); - self.define_function(id, function_type.clone(), turbofish_generics, new_id); + let is_unconstrained = self.is_unconstrained(id); + self.define_function( + id, + function_type.clone(), + turbofish_generics, + is_unconstrained, + new_id, + ); let location = self.interner.expr_location(&expr_id); let bindings = self.interner.get_instantiation_bindings(expr_id); let bindings = self.follow_bindings(bindings); - self.queue.push_back((id, new_id, bindings, trait_method, location)); + self.queue.push_back((id, new_id, bindings, trait_method, is_unconstrained, location)); new_id } @@ -2007,6 +2042,11 @@ impl<'interner> Monomorphizer<'interner> { Ok(ast::Expression::Call(ast::Call { func, arguments, return_type, location })) } + + fn is_unconstrained(&self, func_id: node_interner::FuncId) -> bool { + self.in_unconstrained_function + || self.interner.function_modifiers(&func_id).is_unconstrained + } } fn unwrap_tuple_type(typ: &HirType) -> Vec {