From d0d4416376afc6aea1da9ac90b43fcb121687602 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Thu, 18 May 2023 02:19:34 -0400 Subject: [PATCH] feat: implement constant propagation pass This commit implements a pass which performs constant propagation on a Program, elimintating any variables which are constants by converting any references to constant variables into constant expressions. Likewise, expressions which are constant in all operands are folded into constant values and propagated as well. The resulting Program is simpler and easier to perform local analysis/optimizations on. --- parser/src/lib.rs | 20 + .../src/parser/tests/constant_propagation.rs | 116 +++++ parser/src/parser/tests/mod.rs | 1 + parser/src/parser/tests/utils.rs | 7 +- parser/src/transforms/constant_propagation.rs | 469 ++++++++++++++++++ parser/src/transforms/mod.rs | 3 + 6 files changed, 614 insertions(+), 2 deletions(-) create mode 100644 parser/src/parser/tests/constant_propagation.rs create mode 100644 parser/src/transforms/constant_propagation.rs create mode 100644 parser/src/transforms/mod.rs diff --git a/parser/src/lib.rs b/parser/src/lib.rs index a91c96bf..68de2e92 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -6,6 +6,7 @@ mod lexer; mod parser; mod sema; pub mod symbols; +pub mod transforms; pub use self::parser::{ParseError, Parser}; pub use self::symbols::Symbol; @@ -50,3 +51,22 @@ pub(crate) fn parse_module_from_file>( err @ Err(_) => err, } } + +/// Parses a [Module] from a file already in the codemap +/// +/// This is primarily intended for use in the import resolution phase. +pub(crate) fn parse_module( + diagnostics: &DiagnosticsHandler, + codemap: Arc, + source: Arc, +) -> Result { + let parser = Parser::new((), codemap); + match parser.parse::(diagnostics, source) { + ok @ Ok(_) => ok, + Err(ParseError::Lexer(err)) => { + diagnostics.emit(err); + Err(ParseError::Failed) + } + err @ Err(_) => err, + } +} diff --git a/parser/src/parser/tests/constant_propagation.rs b/parser/src/parser/tests/constant_propagation.rs new file mode 100644 index 00000000..82aca50e --- /dev/null +++ b/parser/src/parser/tests/constant_propagation.rs @@ -0,0 +1,116 @@ +use miden_diagnostics::SourceSpan; + +use pretty_assertions::assert_eq; + +use crate::{ast::*, transforms::ConstantPropagator}; + +use super::ParseTest; + +#[test] +fn test_constant_propagation() { + let root = r#" + def root + + use lib::* + + trace_columns: + main: [clk, a, b[2], c] + + public_inputs: + inputs: [0] + + const A = [2, 4, 6, 8] + const B = [[1, 1], [2, 2]] + + integrity_constraints: + enf test_constraint(b) + let x = 2^EXP + let y = A[0..2] + enf a + y[1] = c + (x + 1) + + boundary_constraints: + let x = B[0] + enf a.first = x[0] + + "#; + let lib = r#" + mod lib + + const EXP = 2 + + ev test_constraint([b0, b1]): + let x = EXP + let y = 2^x + enf b0 + x = b1 + y + "#; + + let test = ParseTest::new(); + let path = std::env::current_dir().unwrap().join("lib.air"); + test.add_virtual_file(path, lib.to_string()); + + let mut program = match test.parse_program(root) { + Err(err) => { + test.diagnostics.emit(err); + panic!("expected parsing to succeed, see diagnostics for details"); + } + Ok(ast) => ast, + }; + + let pass = ConstantPropagator::new(); + pass.run(&mut program).unwrap(); + + let mut expected = Program::new(ident!(root)); + expected.trace_columns.push(trace_segment!( + 0, + "$main", + [(clk, 1), (a, 1), (b, 2), (c, 1)] + )); + expected.public_inputs.insert( + ident!(inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(inputs), 0), + ); + expected + .constants + .insert(ident!(root, A), constant!(A = [2, 4, 6, 8])); + expected + .constants + .insert(ident!(root, B), constant!(B = [[1, 1], [2, 2]])); + expected + .constants + .insert(ident!(lib, EXP), constant!(EXP = 2)); + // When constant propagation is done, the boundary constraints should look like: + // enf a.first = 1 + expected.boundary_constraints.push(enforce!(eq!( + bounded_access!(a, Boundary::First, Type::Felt), + int!(1) + ))); + // When constant propagation is done, the integrity constraints should look like: + // enf test_constraint(b) + // enf a + 4 = c + 5 + expected + .integrity_constraints + .push(enforce!(call!(lib::test_constraint( + access!(b, Type::Vector(2)).into() + )))); + expected.integrity_constraints.push(enforce!(eq!( + add!(access!(a, Type::Felt), int!(4)), + add!(access!(c, Type::Felt), int!(5)) + ))); + // The test_constraint function should look like: + // enf b0 + 2 = b1 + 4 + let body = vec![enforce!(eq!( + add!(access!(b0, Type::Felt), int!(2)), + add!(access!(b1, Type::Felt), int!(4)) + ))]; + expected.evaluators.insert( + function_ident!(lib, test_constraint), + EvaluatorFunction::new( + SourceSpan::UNKNOWN, + ident!(test_constraint), + vec![trace_segment!(0, "%0", [(b0, 1), (b1, 1)])], + body, + ), + ); + + assert_eq!(program, expected); +} diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index 707d7024..1c48e52c 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -601,6 +601,7 @@ macro_rules! import { mod arithmetic_ops; mod boundary_constraints; mod calls; +mod constant_propagation; mod constants; mod evaluators; mod identifiers; diff --git a/parser/src/parser/tests/utils.rs b/parser/src/parser/tests/utils.rs index 7e35a3de..332068c7 100644 --- a/parser/src/parser/tests/utils.rs +++ b/parser/src/parser/tests/utils.rs @@ -55,7 +55,7 @@ impl Emitter for SplitEmitter { /// - ParseError test: check that the parsed values are valid. /// * InvalidInt: This error is returned if the parsed number is not a valid u64. pub struct ParseTest { - diagnostics: Arc, + pub diagnostics: Arc, emitter: Arc, parser: Parser, } @@ -87,7 +87,10 @@ impl ParseTest { } } - #[allow(unused)] + pub fn add_virtual_file>(&self, name: P, content: String) { + self.parser.codemap.add(name.as_ref(), content); + } + pub fn parse_module_from_file(&self, path: &str) -> Result { self.parser .parse_file::(&self.diagnostics, path) diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs new file mode 100644 index 00000000..4c71260b --- /dev/null +++ b/parser/src/transforms/constant_propagation.rs @@ -0,0 +1,469 @@ +use std::{collections::HashMap, ops::ControlFlow}; + +use miden_diagnostics::{SourceSpan, Span, Spanned}; + +use crate::ast::{visit::VisitMut, *}; + +#[derive(Debug, thiserror::Error)] +pub enum InvalidConstantError { + #[error("this value is too large for an exponent")] + InvalidExponent(SourceSpan), +} + +/// This pass performs constant propagation on a [Program], replacing all uses of a constant +/// with the constant itself, converting accesses into constant aggregates with the accessed +/// value, replacing local variables bound to constants with the constant value, and folding +/// constant expressions into constant values. +/// +/// It is expected that the provided [Program] has already been run through semantic analysis, +/// so it will panic if it encounters invalid constructions to help catch bugs in the semantic +/// analysis pass, should they exist. +#[derive(Default)] +pub struct ConstantPropagator { + global: HashMap>, + local: HashMap>, + in_constraint_comprehension: bool, +} +impl ConstantPropagator { + #[inline(always)] + pub fn new() -> Self { + Self::default() + } + + pub fn run(mut self, program: &mut Program) -> Result<(), InvalidConstantError> { + self.global.reserve(program.constants.len()); + + match self.run_visitor(program) { + ControlFlow::Continue(()) => Ok(()), + ControlFlow::Break(err) => Err(err), + } + } + + fn run_visitor(&mut self, program: &mut Program) -> ControlFlow { + // Record all of the constant declarations + for (name, constant) in program.constants.iter() { + assert_eq!( + self.global + .insert(*name, Span::new(constant.span(), constant.value.clone())), + None + ); + } + + // Visit all of the evaluators + for evaluator in program.evaluators.values_mut() { + self.visit_mut_evaluator_function(evaluator)?; + } + + // Visit all of the constraints + self.visit_mut_boundary_constraints(&mut program.boundary_constraints)?; + self.visit_mut_integrity_constraints(&mut program.integrity_constraints) + } +} +impl VisitMut for ConstantPropagator { + /// Fold constant expressions + fn visit_mut_scalar_expr( + &mut self, + expr: &mut ScalarExpr, + ) -> ControlFlow { + let span = expr.span(); + match expr { + // Expression is already folded + ScalarExpr::Const(_) => ControlFlow::Continue(()), + // Need to check if this access is to a constant value, and transform to a constant if so + ScalarExpr::SymbolAccess(sym) => { + let constant_value = match sym.name { + // Possibly a reference to a constant declaration + ResolvableIdentifier::Resolved(ref qid) => { + self.global.get(qid).cloned().map(|s| (s.span(), s.item)) + } + // Possibly a reference to a local bound to a constant + ResolvableIdentifier::Local(ref id) => { + self.local.get(id).cloned().map(|s| (s.span(), s.item)) + } + // Other identifiers cannot possibly be constant + _ => None, + }; + if let Some((span, constant_expr)) = constant_value { + match constant_expr { + ConstantExpr::Scalar(value) => { + assert_eq!(sym.access_type, AccessType::Default); + *expr = ScalarExpr::Const(Span::new(span, value)); + } + ConstantExpr::Vector(value) => match sym.access_type { + AccessType::Index(idx) => { + *expr = ScalarExpr::Const(Span::new(span, value[idx])); + } + ref ty => panic!( + "invalid constant reference, expected scalar access, got {:?}", + ty + ), + }, + ConstantExpr::Matrix(value) => match sym.access_type { + AccessType::Matrix(row, col) => { + *expr = ScalarExpr::Const(Span::new(span, value[row][col])); + } + ref ty => panic!( + "invalid constant reference, expected scalar access, got {:?}", + ty + ), + }, + } + } + ControlFlow::Continue(()) + } + // Fold constant expressions + ScalarExpr::Binary(BinaryExpr { + op, + ref mut lhs, + ref mut rhs, + .. + }) => { + // Visit operands first to ensure they are reduced to constants if possible + self.visit_mut_scalar_expr(lhs)?; + self.visit_mut_scalar_expr(rhs)?; + // If both operands are constant, fold + if let (ScalarExpr::Const(l), ScalarExpr::Const(r)) = (lhs.as_mut(), rhs.as_mut()) { + let folded = match op { + BinaryOp::Add => l.item + r.item, + BinaryOp::Sub => l.item - r.item, + BinaryOp::Mul => l.item * r.item, + BinaryOp::Exp => match r.item.try_into() { + Ok(exp) => l.item.pow(exp), + Err(_) => { + return ControlFlow::Break(InvalidConstantError::InvalidExponent( + span, + )) + } + }, + // This op cannot be folded + BinaryOp::Eq => return ControlFlow::Continue(()), + }; + *expr = ScalarExpr::Const(Span::new(span, folded)); + } + ControlFlow::Continue(()) + } + // While calls cannot be constant folded, arguments can be + ScalarExpr::Call(ref mut call) => self.visit_mut_call(call), + // This cannot be constant folded + ScalarExpr::BoundedSymbolAccess(_) => ControlFlow::Continue(()), + } + } + + fn visit_mut_expr(&mut self, expr: &mut Expr) -> ControlFlow { + let span = expr.span(); + match expr { + // Already constant + Expr::Const(_) => ControlFlow::Continue(()), + // Lift to `Expr::Const` if the scalar expression is constant + // + // We deal with symbol accesses directly, as they may evaluate to an aggregate constant + Expr::Scalar(ScalarExpr::SymbolAccess(ref mut access)) => { + let constant_value = match access.name { + // Possibly a reference to a constant declaration + ResolvableIdentifier::Resolved(ref qid) => { + self.global.get(qid).cloned().map(|s| (s.span(), s.item)) + } + // Possibly a reference to a local bound to a constant + ResolvableIdentifier::Local(ref id) => { + self.local.get(id).cloned().map(|s| (s.span(), s.item)) + } + // Other identifiers cannot possibly be constant + _ => None, + }; + if let Some((span, constant_expr)) = constant_value { + match constant_expr { + cexpr @ ConstantExpr::Scalar(_) => { + assert_eq!(access.access_type, AccessType::Default); + *expr = Expr::Const(Span::new(span, cexpr)); + } + ConstantExpr::Vector(value) => match access.access_type.clone() { + AccessType::Default => { + *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(value))); + } + AccessType::Slice(range) => { + let vector = value[range.start..range.end].to_vec(); + *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(vector))); + } + AccessType::Index(idx) => { + *expr = + Expr::Const(Span::new(span, ConstantExpr::Scalar(value[idx]))); + } + ref ty => panic!( + "invalid constant reference, expected scalar access, got {:?}", + ty + ), + }, + ConstantExpr::Matrix(value) => match access.access_type.clone() { + AccessType::Default => { + *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(value))); + } + AccessType::Slice(range) => { + let matrix = value[range.start..range.end].to_vec(); + *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(matrix))); + } + AccessType::Index(idx) => { + *expr = Expr::Const(Span::new( + span, + ConstantExpr::Vector(value[idx].clone()), + )); + } + AccessType::Matrix(row, col) => { + *expr = Expr::Const(Span::new( + span, + ConstantExpr::Scalar(value[row][col]), + )); + } + }, + } + } + ControlFlow::Continue(()) + } + Expr::Scalar(ref mut scalar) => { + self.visit_mut_scalar_expr(scalar)?; + if let ScalarExpr::Const(value) = scalar { + let value = Expr::Const(Span::new(span, ConstantExpr::Scalar(value.item))); + *expr = value; + } + ControlFlow::Continue(()) + } + // Ranges are constant + Expr::Range(_) => ControlFlow::Continue(()), + // Visit vector elements, and promote the vector to `Expr::Const` if possible + Expr::Vector(ref mut vector) => { + let mut is_constant = true; + for elem in vector.iter_mut() { + self.visit_mut_scalar_expr(elem)?; + is_constant &= elem.is_constant(); + } + if is_constant { + let vector = ConstantExpr::Vector( + vector + .iter() + .map(|sexpr| match sexpr { + ScalarExpr::Const(elem) => elem.item, + _ => unreachable!(), + }) + .collect(), + ); + *expr = Expr::Const(Span::new(span, vector)); + } + ControlFlow::Continue(()) + } + // Visit matrix elements, and promote the matrix to `Expr::Const` if possible + Expr::Matrix(ref mut matrix) => { + let mut is_constant = true; + for row in matrix.iter_mut() { + for column in row.iter_mut() { + self.visit_mut_scalar_expr(column)?; + is_constant &= column.is_constant(); + } + } + if is_constant { + let matrix = ConstantExpr::Matrix( + matrix + .iter() + .map(|row| { + row.iter() + .map(|col| match col { + ScalarExpr::Const(elem) => elem.item, + _ => unreachable!(), + }) + .collect::>() + }) + .collect(), + ); + *expr = Expr::Const(Span::new(span, matrix)); + } + ControlFlow::Continue(()) + } + // Visit list comprehensions and convert to constant if possible + Expr::ListComprehension(ref mut lc) => { + let mut has_constant_iterables = true; + for iterable in lc.iterables.iter_mut() { + self.visit_mut_expr(iterable)?; + has_constant_iterables &= iterable.is_constant(); + } + + // If we have constant iterables, drive the comprehension, evaluating it at + // each step. If any part of the body cannot be compile-time evaluated, then + // we bail early, as the comprehension can only be folded if all parts of it + // are constant. + if !has_constant_iterables { + return ControlFlow::Continue(()); + } + + // Start a new lexical scope + let prev = self.local.clone(); + + // All iterables must be the same length, so determine the number of + // steps based on the length of the first iterable + let max_len = match &lc.iterables[0] { + Expr::Const(Span { + item: ConstantExpr::Vector(elems), + .. + }) => elems.len(), + Expr::Const(Span { + item: ConstantExpr::Matrix(rows), + .. + }) => rows.len(), + Expr::Const(_) => panic!("expected iterable constant, got scalar"), + Expr::Range(range) => range.end - range.start, + _ => unreachable!(), + }; + + // Drive the comprehension step-by-step + let mut folded = vec![]; + for step in 0..max_len { + for (binding, iterable) in lc.bindings.iter().copied().zip(lc.iterables.iter()) + { + let span = iterable.span(); + match iterable { + Expr::Const(Span { + item: ConstantExpr::Vector(elems), + .. + }) => { + let value = ConstantExpr::Scalar(elems[step]); + self.local.insert(binding, Span::new(span, value)); + } + Expr::Const(Span { + item: ConstantExpr::Matrix(elems), + .. + }) => { + let value = ConstantExpr::Vector(elems[step].clone()); + self.local.insert(binding, Span::new(span, value)); + } + Expr::Range(range) => { + assert!(range.end > range.start + step); + let value = ConstantExpr::Scalar((range.start + step) as u64); + self.local.insert(binding, Span::new(span, value)); + } + _ => unreachable!(), + } + } + + if let Some(mut selector) = lc.selector.as_ref().cloned() { + self.visit_mut_scalar_expr(&mut selector)?; + match selector { + ScalarExpr::Const(selected) => { + // If the selector returns false on this iteration, go to the next step + if *selected == 0 { + continue; + } + } + // The selector cannot be evaluated, bail out early + _ => return ControlFlow::Continue(()), + } + } + + let mut body = lc.body.as_ref().clone(); + self.visit_mut_scalar_expr(&mut body)?; + + // If the body is constant, store the result in the vector, otherwise we must + // bail because this comprehension cannot be folded + if let ScalarExpr::Const(folded_body) = body { + folded.push(folded_body.item); + } else { + return ControlFlow::Continue(()); + } + } + + // Exit lexical scope + self.local = prev; + + // If we reach here, the comprehension was expanded to a constant vector + *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(folded))); + ControlFlow::Continue(()) + } + } + } + + fn visit_mut_statement_block( + &mut self, + statements: &mut Vec, + ) -> ControlFlow { + let mut current_statement = 0; + + let mut buffer = vec![]; + while current_statement < statements.len() { + let num_statements = statements.len(); + match &mut statements[current_statement] { + Statement::Let(ref mut expr) => { + // A `let` may only appear once in a statement block, and must be the + // last statement in the block + assert_eq!( + current_statement, + num_statements - 1, + "let is not in tail position of block" + ); + // Visit the binding expression first + self.visit_mut_expr(&mut expr.value)?; + // Enter a new lexical scope + let prev = self.local.clone(); + // If the value is constant, record it in our bindings map + let is_constant = expr.value.is_constant(); + if is_constant { + match expr.value { + Expr::Const(ref value) => { + self.local.insert(expr.name, value.clone()); + } + Expr::Range(ref range) => { + let vector = + range.item.clone().into_iter().map(|i| i as u64).collect(); + self.local.insert( + expr.name, + Span::new(range.span(), ConstantExpr::Vector(vector)), + ); + } + _ => unreachable!(), + } + } + + // Visit the let body + self.visit_mut_statement_block(&mut expr.body)?; + + // If this let is constant, then the binding is no longer + // used in the body after constant propagation, flatten its + // body into the current block. + if is_constant { + buffer.append(&mut expr.body); + } + + // Restore the previous scope + self.local = prev; + } + Statement::Enforce(ref mut expr) => { + self.visit_mut_enforce(expr)?; + } + Statement::EnforceAll(ref mut expr) => { + self.in_constraint_comprehension = true; + self.visit_mut_list_comprehension(expr)?; + self.in_constraint_comprehension = false; + } + } + + // If we have a non-empty buffer, then we are collapsing a let into the current block, + // and that let must have been the last expression in the block, so as soon as we fold + // its body into the current block, we're done + if buffer.is_empty() { + current_statement += 1; + continue; + } + + // Drop the let statement being folded in to this block + statements.pop(); + + // Append the buffer + statements.append(&mut buffer); + + // We're done + break; + } + + ControlFlow::Continue(()) + } + + /// It should not be possible to reach this, as we handle statements at the block level + fn visit_mut_statement(&mut self, _: &mut Statement) -> ControlFlow { + panic!("unexpectedly reached visit_mut_statement"); + } +} diff --git a/parser/src/transforms/mod.rs b/parser/src/transforms/mod.rs new file mode 100644 index 00000000..2cabc327 --- /dev/null +++ b/parser/src/transforms/mod.rs @@ -0,0 +1,3 @@ +mod constant_propagation; + +pub use self::constant_propagation::{ConstantPropagator, InvalidConstantError};