diff --git a/crates/rue-compiler/src/codegen.rs b/crates/rue-compiler/src/codegen.rs index a0aaad1..1a00543 100644 --- a/crates/rue-compiler/src/codegen.rs +++ b/crates/rue-compiler/src/codegen.rs @@ -76,7 +76,6 @@ impl<'a> Codegen<'a> { Lir::Run(program, args) => self.gen_run(program, args), Lir::Curry(body, args) => self.gen_apply(body, args), Lir::Closure(body, args) => self.gen_closure(body, args), - Lir::FunctionBody(body) => self.gen_quote(body), Lir::First(value) => self.gen_first(value), Lir::Rest(value) => self.gen_rest(value), Lir::Raise(value) => self.gen_raise(value), diff --git a/crates/rue-compiler/src/lir.rs b/crates/rue-compiler/src/lir.rs index 05b353b..de3c00a 100644 --- a/crates/rue-compiler/src/lir.rs +++ b/crates/rue-compiler/src/lir.rs @@ -9,7 +9,6 @@ pub enum Lir { Quote(LirId), Curry(LirId, Vec), Closure(LirId, Vec), - FunctionBody(LirId), First(LirId), Rest(LirId), Raise(Option), diff --git a/crates/rue-compiler/src/optimizer.rs b/crates/rue-compiler/src/optimizer.rs index 4730119..9816379 100644 --- a/crates/rue-compiler/src/optimizer.rs +++ b/crates/rue-compiler/src/optimizer.rs @@ -14,6 +14,7 @@ mod environment; pub use dependency_graph::*; pub use environment::*; +use indexmap::IndexSet; pub struct Optimizer<'a> { db: &'a mut Database, @@ -34,21 +35,10 @@ impl<'a> Optimizer<'a> { let Symbol::Function(fun) = self.db.symbol(main).clone() else { unreachable!(); }; - let env_id = self.graph.env(fun.scope_id); - let body = self.opt_hir(env_id, fun.hir_id); - - let mut args = Vec::new(); - - for symbol_id in self.db.env(env_id).definitions() { - args.push(self.opt_definition(env_id, symbol_id)); - } - - for symbol_id in self.db.env(env_id).captures() { - args.push(self.opt_definition(env_id, symbol_id)); - } - - self.db.alloc_lir(Lir::Curry(body, args)) + let mut definitions = self.db.env(env_id).definitions(); + definitions.extend(self.db.env(env_id).captures()); + self.opt_definitions(env_id, definitions, fun.hir_id) } fn opt_path(&mut self, env_id: EnvironmentId, symbol_id: SymbolId) -> LirId { @@ -59,7 +49,7 @@ impl<'a> Optimizer<'a> { } let mut current_env_id = env_id; - let mut environment = self.db.env(env_id).build().clone(); + let mut environment: Vec = self.db.env(env_id).build().into_iter().collect(); while let Some(parent_env_id) = self.db.env(current_env_id).parent() { assert!(self.db.env(current_env_id).parameters().is_empty()); @@ -99,19 +89,12 @@ impl<'a> Optimizer<'a> { hir_id, scope_id, .. }) => { let function_env_id = self.graph.env(scope_id); - - let mut body = self.opt_hir(function_env_id, hir_id); - let mut definitions = Vec::new(); - - for symbol_id in self.db.env(function_env_id).definitions() { - definitions.push(self.opt_definition(function_env_id, symbol_id)); - } - - if !definitions.is_empty() { - body = self.db.alloc_lir(Lir::Curry(body, definitions)); - } - - self.db.alloc_lir(Lir::FunctionBody(body)) + let function = self.opt_definitions( + function_env_id, + self.db.env(function_env_id).definitions(), + hir_id, + ); + self.db.alloc_lir(Lir::Quote(function)) } Symbol::Const(Value { hir_id, .. }) => self.opt_hir(env_id, hir_id), Symbol::Let(symbol) if self.graph.symbol_usages(symbol_id) > 0 => { @@ -126,6 +109,52 @@ impl<'a> Optimizer<'a> { } } + fn opt_definitions( + &mut self, + mut env_id: EnvironmentId, + definitions: IndexSet, + body: HirId, + ) -> LirId { + let mut remaining: IndexSet = definitions.into_iter().collect(); + let mut curries = Vec::new(); + + while !remaining.is_empty() { + let no_references = remaining + .iter() + .filter(|&symbol_id| { + !self.db.symbol(*symbol_id).is_constant() + || self + .graph + .constant_references(*symbol_id) + .intersection(&remaining) + .count() + == 0 + }) + .copied() + .collect::>(); + + let mut args = Vec::new(); + + for &symbol_id in &no_references { + args.push(self.opt_definition(env_id, symbol_id)); + } + + curries.push(args); + remaining.retain(|&symbol_id| !no_references.contains(&symbol_id)); + if !remaining.is_empty() { + env_id = self.db.alloc_env(Environment::binding(env_id)); + } + } + + let mut body = self.opt_hir(env_id, body); + + for args in curries.into_iter().rev() { + body = self.db.alloc_lir(Lir::Curry(body, args)); + } + + body + } + fn opt_hir(&mut self, env_id: EnvironmentId, hir_id: HirId) -> LirId { match self.db.hir(hir_id).clone() { Hir::Unknown => self.db.alloc_lir(Lir::Atom(Vec::new())), @@ -134,18 +163,6 @@ impl<'a> Optimizer<'a> { Hir::Reference(symbol_id, ..) => self.opt_reference(env_id, symbol_id), Hir::CheckExists(value) => self.opt_check_exists(env_id, value), Hir::Definition { scope_id, hir_id } => { - let definition_env_id = self.graph.env(scope_id); - for symbol_id in self.db.env_mut(definition_env_id).definitions() { - let Symbol::Let(..) = self.db.symbol(symbol_id) else { - continue; - }; - - if self.graph.symbol_usages(symbol_id) == 1 { - self.db - .env_mut(definition_env_id) - .remove_definition(symbol_id); - } - } self.opt_env_definition(env_id, scope_id, hir_id) } Hir::FunctionCall { @@ -239,6 +256,17 @@ impl<'a> Optimizer<'a> { scope_id: ScopeId, hir_id: HirId, ) -> LirId { + let definition_env_id = self.graph.env(scope_id); + for symbol_id in self.db.env_mut(definition_env_id).definitions() { + let Symbol::Let(..) = self.db.symbol(symbol_id) else { + continue; + }; + if self.graph.symbol_usages(symbol_id) == 1 { + self.db + .env_mut(definition_env_id) + .remove_definition(symbol_id); + } + } let child_env_id = self.graph.env(scope_id); let body = self.opt_hir(child_env_id, hir_id); @@ -376,20 +404,6 @@ impl<'a> Optimizer<'a> { let mut inline_parameter_map = HashMap::new(); let mut args = args; - let env = self.db.env(function_env_id).clone(); - for symbol_id in env.definitions() { - if self.db.symbol(symbol_id).is_parameter() { - continue; - } - self.db.env_mut(env_id).define(symbol_id); - } - for symbol_id in env.captures() { - if self.db.symbol(symbol_id).is_parameter() { - continue; - } - self.db.env_mut(env_id).capture(symbol_id); - } - let param_len = self.db.env(function_env_id).parameters().len(); for (i, symbol_id) in self diff --git a/crates/rue-compiler/src/optimizer/dependency_graph.rs b/crates/rue-compiler/src/optimizer/dependency_graph.rs index e90cd3e..205c5a3 100644 --- a/crates/rue-compiler/src/optimizer/dependency_graph.rs +++ b/crates/rue-compiler/src/optimizer/dependency_graph.rs @@ -18,6 +18,7 @@ use super::Environment; pub struct DependencyGraph { env: IndexMap, symbol_usages: HashMap, + constant_references: HashMap>, } impl DependencyGraph { @@ -34,7 +35,14 @@ impl DependencyGraph { } pub fn symbol_usages(&self, symbol_id: SymbolId) -> usize { - *self.symbol_usages.get(&symbol_id).unwrap_or(&0) + self.symbol_usages.get(&symbol_id).copied().unwrap_or(0) + } + + pub fn constant_references(&self, symbol_id: SymbolId) -> IndexSet { + self.constant_references + .get(&symbol_id) + .cloned() + .unwrap_or_default() } } @@ -42,7 +50,7 @@ struct GraphTraversal<'a> { db: &'a mut Database, graph: DependencyGraph, edges: HashMap>, - constant_reference_stack: HashSet, + constant_reference_stack: IndexSet, } impl<'a> GraphTraversal<'a> { @@ -51,7 +59,7 @@ impl<'a> GraphTraversal<'a> { db, graph: DependencyGraph::default(), edges: HashMap::new(), - constant_reference_stack: HashSet::new(), + constant_reference_stack: IndexSet::new(), } } @@ -224,6 +232,14 @@ impl<'a> GraphTraversal<'a> { ) { let symbol = self.db.symbol(symbol_id).clone(); + for &definition_id in &self.constant_reference_stack { + self.graph + .constant_references + .entry(definition_id) + .or_default() + .insert(symbol_id); + } + if symbol.is_constant() && !self.constant_reference_stack.insert(symbol_id) { return; } @@ -231,7 +247,7 @@ impl<'a> GraphTraversal<'a> { self.graph .symbol_usages .entry(symbol_id) - .and_modify(|count| *count += 1) + .and_modify(|usages| *usages += 1) .or_insert(1); self.propagate_capture(scope_id, symbol_id, &mut HashSet::new()); @@ -265,16 +281,34 @@ impl<'a> GraphTraversal<'a> { // Functions are visited in the scope in which they are defined. // TODO: For inline functions, should this be visited in the current scope? - Symbol::Function(fun) | Symbol::InlineFunction(fun) => { + Symbol::Function(fun) => { self.visit_hir(fun.scope_id, fun.hir_id, visited); } + Symbol::InlineFunction(fun) => { + self.visit_hir(fun.scope_id, fun.hir_id, visited); + + let env = self.db.env(self.graph.env(fun.scope_id)).clone(); + for symbol_id in env.definitions() { + if self.db.symbol(symbol_id).is_parameter() { + continue; + } + self.db.env_mut(self.graph.env(scope_id)).define(symbol_id); + } + for symbol_id in env.captures() { + if self.db.symbol(symbol_id).is_parameter() { + continue; + } + self.db.env_mut(self.graph.env(scope_id)).capture(symbol_id); + } + } + // Parameters don't need to be visited, since currently // they are just a reference to the environment. Symbol::Parameter(..) => {} } - self.constant_reference_stack.remove(&symbol_id); + self.constant_reference_stack.shift_remove(&symbol_id); } fn compute_edges(&mut self, symbol_id: SymbolId) { @@ -400,7 +434,7 @@ impl<'a> GraphTraversal<'a> { } } - self.constant_reference_stack.remove(&symbol_id); + self.constant_reference_stack.shift_remove(&symbol_id); } } } diff --git a/crates/rue-compiler/src/optimizer/environment.rs b/crates/rue-compiler/src/optimizer/environment.rs index e48e8f8..07e1eec 100644 --- a/crates/rue-compiler/src/optimizer/environment.rs +++ b/crates/rue-compiler/src/optimizer/environment.rs @@ -39,16 +39,16 @@ impl Environment { self.captured_symbols.insert(symbol_id); } - pub fn definitions(&self) -> Vec { - self.defined_symbols.iter().copied().collect() + pub fn definitions(&self) -> IndexSet { + self.defined_symbols.clone() } - pub fn captures(&self) -> Vec { - self.captured_symbols.iter().copied().collect() + pub fn captures(&self) -> IndexSet { + self.captured_symbols.clone() } - pub fn parameters(&self) -> Vec { - self.parameters.iter().copied().collect() + pub fn parameters(&self) -> IndexSet { + self.parameters.clone() } pub fn rest_parameter(&self) -> bool { @@ -59,8 +59,8 @@ impl Environment { self.parent } - pub fn build(&self) -> Vec { - let mut symbol_ids = Vec::new(); + pub fn build(&self) -> IndexSet { + let mut symbol_ids = IndexSet::new(); symbol_ids.extend(self.defined_symbols.iter().copied()); if self.parent.is_none() {