From bda7dc5bd8caaa16ce434debc422ffbe4a3e8532 Mon Sep 17 00:00:00 2001 From: hmuro andrej Date: Sat, 25 Feb 2023 19:06:22 +0300 Subject: [PATCH] code improvement --- air-script/Cargo.toml | 2 +- air-script/tests/main.rs | 1 + codegen/gce/Cargo.toml | 2 +- codegen/gce/src/error.rs | 8 +- codegen/gce/src/expressions.rs | 173 ++++++++++++++++++--------------- codegen/gce/src/lib.rs | 44 ++++----- 6 files changed, 124 insertions(+), 106 deletions(-) diff --git a/air-script/Cargo.toml b/air-script/Cargo.toml index 42b4bb80..c6b82f7a 100644 --- a/air-script/Cargo.toml +++ b/air-script/Cargo.toml @@ -18,7 +18,7 @@ path = "src/main.rs" [dependencies] codegen-winter = { package = "air-codegen-winter", path = "../codegen/winterfell", version = "0.2.0" } -codegen-gce = { package = "air-codegen-gce", path = "../codegen/gce", version = "0.1.0" } +codegen-gce = { package = "air-codegen-gce", path = "../codegen/gce", version = "0.2.0" } env_logger = "0.10.0" ir = { package = "air-ir", path = "../ir", version = "0.2.0" } log = { version = "0.4", default-features = false } diff --git a/air-script/tests/main.rs b/air-script/tests/main.rs index 8b7e7f92..8e0db34a 100644 --- a/air-script/tests/main.rs +++ b/air-script/tests/main.rs @@ -1,6 +1,7 @@ use expect_test::expect_file; use std::fs::{self, File}; use std::io::prelude::*; + mod helpers; use helpers::Test; diff --git a/codegen/gce/Cargo.toml b/codegen/gce/Cargo.toml index ec8562ff..08b493aa 100644 --- a/codegen/gce/Cargo.toml +++ b/codegen/gce/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "air-codegen-gce" -version = "0.1.0" +version = "0.2.0" description="Code generation for the generic constraint evaluation format." authors = ["miden contributors"] readme="README.md" diff --git a/codegen/gce/src/error.rs b/codegen/gce/src/error.rs index f96aae79..f841767f 100644 --- a/codegen/gce/src/error.rs +++ b/codegen/gce/src/error.rs @@ -10,23 +10,23 @@ pub enum ConstraintEvaluationError { impl ConstraintEvaluationError { pub fn invalid_trace_segment(segment: u8) -> Self { ConstraintEvaluationError::InvalidTraceSegment(format!( - "Trace segment {segment} is invalid" + "Trace segment {segment} is invalid." )) } pub fn constant_not_found(name: &str) -> Self { - ConstraintEvaluationError::ConstantNotFound(format!("Constant \"{name}\" not found")) + ConstraintEvaluationError::ConstantNotFound(format!("Constant \"{name}\" not found.")) } pub fn invalid_constant_type(name: &str, constant_type: &str) -> Self { ConstraintEvaluationError::InvalidConstantType(format!( - "Invalid type of constant \"{name}\". {constant_type} exprected." + "Invalid type of constant \"{name}\". Expected \"{constant_type}\"." )) } pub fn operation_not_found(index: usize) -> Self { ConstraintEvaluationError::OperationNotFound(format!( - "Operation with index {index} does not match the expression in the expressions JSON array" + "Operation with index {index} does not match the expression in the JSON expressions array." )) } } diff --git a/codegen/gce/src/expressions.rs b/codegen/gce/src/expressions.rs index b5027236..6d9ebf4f 100644 --- a/codegen/gce/src/expressions.rs +++ b/codegen/gce/src/expressions.rs @@ -15,108 +15,126 @@ use std::collections::BTreeMap; const MAIN_TRACE_SEGMENT_INDEX: u8 = 0; -pub struct ExpressionsHandler<'a> { - ir: &'a AirIR, - constants: &'a [u64], +pub struct GceBuilder { // maps indexes in Node vector in AlgebraicGraph and in `expressions` JSON array - expressions_map: &'a mut BTreeMap, + expressions_map: BTreeMap, + expressions: Vec, + outputs: Vec, } -impl<'a> ExpressionsHandler<'a> { - pub fn new( - ir: &'a AirIR, - constants: &'a [u64], - expressions_map: &'a mut BTreeMap, - ) -> Self { - ExpressionsHandler { - ir, - constants, - expressions_map, +impl GceBuilder { + pub fn new() -> Self { + GceBuilder { + expressions_map: BTreeMap::new(), + expressions: Vec::new(), + outputs: Vec::new(), } } + pub fn build( + &mut self, + ir: &AirIR, + constants: &[u64], + ) -> Result<(), ConstraintEvaluationError> { + self.build_expressions(ir, constants)?; + self.build_outputs(ir)?; + Ok(()) + } + + pub fn into_gce(self) -> Result<(Vec, Vec), ConstraintEvaluationError> { + Ok((self.expressions, self.outputs)) + } + /// Parses expressions in transition graph's Node vector, creates [Expression] instances and pushes /// them to the `expressions` vector. - pub fn get_expressions(&mut self) -> Result, ConstraintEvaluationError> { + fn build_expressions( + &mut self, + ir: &AirIR, + constants: &[u64], + ) -> Result<(), ConstraintEvaluationError> { // TODO: currently we can't create a node reference to the last row (which is required for // main.last and aux.last boundary constraints). Working in assumption that first reference to // the column is .first constraint and second is .last constraint (in the boundary section, not // entire array) - let mut expressions = Vec::new(); - for (index, node) in self.ir.constraint_graph().nodes().iter().enumerate() { + for (index, node) in ir.constraint_graph().nodes().iter().enumerate() { match node.op() { Operation::Add(l, r) => { - expressions.push(self.handle_transition_expression( + self.expressions.push(self.handle_transition_expression( + ir, + constants, ExpressionOperation::Add, *l, *r, )?); // create mapping (index in node graph: index in expressions vector) - self.expressions_map.insert(index, expressions.len() - 1); + self.expressions_map + .insert(index, self.expressions.len() - 1); } Operation::Sub(l, r) => { - expressions.push(self.handle_transition_expression( + self.expressions.push(self.handle_transition_expression( + ir, + constants, ExpressionOperation::Sub, *l, *r, )?); - self.expressions_map.insert(index, expressions.len() - 1); + self.expressions_map + .insert(index, self.expressions.len() - 1); } Operation::Mul(l, r) => { - expressions.push(self.handle_transition_expression( + self.expressions.push(self.handle_transition_expression( + ir, + constants, ExpressionOperation::Mul, *l, *r, )?); - self.expressions_map.insert(index, expressions.len() - 1); + self.expressions_map + .insert(index, self.expressions.len() - 1); } Operation::Exp(i, degree) => { match degree { 0 => { // I decided that node^0 could be emulated using the product of 1*1, but perhaps there are better ways - let index_of_1 = get_constant_index_by_value(1, self.constants)?; + let index_of_1 = get_constant_index_by_value(1, constants)?; let const_1_node = NodeReference { node_type: NodeType::Const, index: index_of_1, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: const_1_node.clone(), rhs: const_1_node, }); } 1 => { - let lhs = self.handle_node_reference(*i)?; - let degree_index = get_constant_index_by_value(1, self.constants)?; + let lhs = self.handle_node_reference(ir, constants, *i)?; + let degree_index = get_constant_index_by_value(1, constants)?; let rhs = NodeReference { node_type: NodeType::Const, index: degree_index, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs, rhs, }); } - _ => self.handle_exponentiation(&mut expressions, *i, *degree)?, + _ => self.handle_exponentiation(ir, constants, *i, *degree)?, } - self.expressions_map.insert(index, expressions.len() - 1); + self.expressions_map + .insert(index, self.expressions.len() - 1); } _ => {} } } - Ok(expressions) + Ok(()) } /// Fills the `outputs` vector with indexes from `expressions` vector according to the `expressions_map`. - pub fn get_outputs( - &self, - expressions: &mut Vec, - ) -> Result, ConstraintEvaluationError> { - let mut outputs = Vec::new(); - - for i in 0..self.ir.segment_widths().len() { - for root in self.ir.boundary_constraints(i as u8) { + fn build_outputs(&mut self, ir: &AirIR) -> Result<(), ConstraintEvaluationError> { + for i in 0..ir.segment_widths().len() { + for root in ir.boundary_constraints(i as u8) { let index = self .expressions_map .get(&root.node_index().index()) @@ -126,35 +144,35 @@ impl<'a> ExpressionsHandler<'a> { // if we found index twice, put the corresponding expression in the expressions // array again. It means that we have equal boundary constraints for both first // and last domains (e.g. a.first = 1 and a.last = 1) - if outputs.contains(index) { - expressions.push(expressions[*index].clone()); - outputs.push(expressions.len() - 1); + if self.outputs.contains(index) { + self.expressions.push(self.expressions[*index].clone()); + self.outputs.push(self.expressions.len() - 1); } else { - outputs.push(*index); + self.outputs.push(*index); } } - for root in self.ir.validity_constraints(i as u8) { + for root in ir.validity_constraints(i as u8) { let index = self .expressions_map .get(&root.node_index().index()) .ok_or_else(|| { ConstraintEvaluationError::operation_not_found(root.node_index().index()) })?; - outputs.push(*index); + self.outputs.push(*index); } - for root in self.ir.transition_constraints(i as u8) { + for root in ir.transition_constraints(i as u8) { let index = self .expressions_map .get(&root.node_index().index()) .ok_or_else(|| { ConstraintEvaluationError::operation_not_found(root.node_index().index()) })?; - outputs.push(*index); + self.outputs.push(*index); } } - Ok(outputs) + Ok(()) } // --- HELPERS -------------------------------------------------------------------------------- @@ -162,12 +180,14 @@ impl<'a> ExpressionsHandler<'a> { /// Parses expression in transition graph Node vector and returns related [Expression] instance. fn handle_transition_expression( &self, + ir: &AirIR, + constants: &[u64], op: ExpressionOperation, l: NodeIndex, r: NodeIndex, ) -> Result { - let lhs = self.handle_node_reference(l)?; - let rhs = self.handle_node_reference(r)?; + let lhs = self.handle_node_reference(ir, constants, l)?; + let rhs = self.handle_node_reference(ir, constants, r)?; Ok(ExpressionJson { op, lhs, rhs }) } @@ -175,10 +195,12 @@ impl<'a> ExpressionsHandler<'a> { /// [NodeReference] instance. fn handle_node_reference( &self, + ir: &AirIR, + constants: &[u64], i: NodeIndex, ) -> Result { use Operation::*; - match self.ir.constraint_graph().node(&i).op() { + match ir.constraint_graph().node(&i).op() { Add(_, _) | Sub(_, _) | Mul(_, _) | Exp(_, _) => { let index = self .expressions_map @@ -192,14 +214,14 @@ impl<'a> ExpressionsHandler<'a> { Constant(constant_value) => { match constant_value { ConstantValue::Inline(v) => { - let index = get_constant_index_by_value(*v, self.constants)?; + let index = get_constant_index_by_value(*v, constants)?; Ok(NodeReference { node_type: NodeType::Const, index, }) } ConstantValue::Scalar(name) => { - let index = get_constant_index_by_name(self.ir, name, self.constants)?; + let index = get_constant_index_by_name(ir, name, constants)?; Ok(NodeReference { node_type: NodeType::Const, index, @@ -208,22 +230,16 @@ impl<'a> ExpressionsHandler<'a> { ConstantValue::Vector(vector_access) => { // why Constant.name() returns Identifier and VectorAccess.name() works like // VectorAccess.name.name() and returns &str? (same with MatrixAccess) - let index = get_constant_index_by_vector_access( - self.ir, - vector_access, - self.constants, - )?; + let index = + get_constant_index_by_vector_access(ir, vector_access, constants)?; Ok(NodeReference { node_type: NodeType::Const, index, }) } ConstantValue::Matrix(matrix_access) => { - let index = get_constant_index_by_matrix_access( - self.ir, - matrix_access, - self.constants, - )?; + let index = + get_constant_index_by_matrix_access(ir, matrix_access, constants)?; Ok(NodeReference { node_type: NodeType::Const, index, @@ -248,8 +264,8 @@ impl<'a> ExpressionsHandler<'a> { }) } } - i if i < self.ir.segment_widths().len() as u8 => { - let col_index = self.ir.segment_widths()[0..i as usize].iter().sum::() + i if i < ir.segment_widths().len() as u8 => { + let col_index = ir.segment_widths()[0..i as usize].iter().sum::() as usize + trace_access.col_idx(); if trace_access.row_offset() == 0 { @@ -270,14 +286,14 @@ impl<'a> ExpressionsHandler<'a> { } } RandomValue(rand_index) => { - let index = get_random_value_index(self.ir, rand_index); + let index = get_random_value_index(ir, rand_index); Ok(NodeReference { node_type: NodeType::Var, index, }) } PublicInput(name, public_index) => { - let index = get_public_input_index(self.ir, name, public_index); + let index = get_public_input_index(ir, name, public_index); Ok(NodeReference { node_type: NodeType::Var, index, @@ -291,20 +307,21 @@ impl<'a> ExpressionsHandler<'a> { /// Replaces the exponentiation operation with multiplication operations, adding them to the /// expressions vector. fn handle_exponentiation( - &self, - expressions: &mut Vec, + &mut self, + ir: &AirIR, + constants: &[u64], i: NodeIndex, degree: usize, ) -> Result<(), ConstraintEvaluationError> { // base node that we want to raise to a degree - let base_node = self.handle_node_reference(i)?; + let base_node = self.handle_node_reference(ir, constants, i)?; // push node^2 expression - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: base_node.clone(), rhs: base_node.clone(), }); - let square_node_index = expressions.len() - 1; + let square_node_index = self.expressions.len() - 1; // square the previous expression while there is such an opportunity let mut cur_degree_of_2 = 1; // currently we have node^(2^cur_degree_of_2) = node^(2^1) = node^2 @@ -312,9 +329,9 @@ impl<'a> ExpressionsHandler<'a> { // the last node that we want to square let last_node = NodeReference { node_type: NodeType::Expr, - index: expressions.len() - 1, + index: self.expressions.len() - 1, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: last_node.clone(), rhs: last_node, @@ -330,9 +347,9 @@ impl<'a> ExpressionsHandler<'a> { // if we need to add first degree (base node) let last_node = NodeReference { node_type: NodeType::Expr, - index: expressions.len() - 1, + index: self.expressions.len() - 1, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: last_node, rhs: base_node, @@ -342,14 +359,14 @@ impl<'a> ExpressionsHandler<'a> { if 2_usize.pow(cur_degree_of_2 - 1) <= diff { let last_node = NodeReference { node_type: NodeType::Expr, - index: expressions.len() - 1, + index: self.expressions.len() - 1, }; let fitting_degree_of_2_node = NodeReference { node_type: NodeType::Expr, // cur_degree_of_2 shows how many indexes we need to add to reach the largest fitting degree of 2 index: square_node_index + cur_degree_of_2 as usize - 2, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: last_node, rhs: fitting_degree_of_2_node, diff --git a/codegen/gce/src/lib.rs b/codegen/gce/src/lib.rs index bc97f306..e3a6b826 100644 --- a/codegen/gce/src/lib.rs +++ b/codegen/gce/src/lib.rs @@ -1,13 +1,14 @@ -use ir::{ - constraints::{ConstantValue, Operation}, - AirIR, -}; - pub use air_script_core::{ Constant, ConstantType, Expression, Identifier, IndexedTraceAccess, MatrixAccess, NamedTraceAccess, TraceSegment, Variable, VariableType, VectorAccess, }; +use ir::{ + constraints::{ConstantValue, Operation}, + AirIR, +}; use std::fmt::Display; +use std::fs::File; +use std::io::Write; mod error; use error::ConstraintEvaluationError; @@ -15,13 +16,10 @@ use error::ConstraintEvaluationError; mod utils; mod expressions; -use expressions::ExpressionsHandler; - -use std::collections::BTreeMap; -use std::fs::File; -use std::io::Write; +use expressions::GceBuilder; -/// Holds data for JSON generation +/// CodeGenerator is used to generate a JSON file with generic constraint evaluation. The generated +/// file contains the data used for GPU acceleration. #[derive(Default, Debug)] pub struct CodeGenerator { num_polys: u16, @@ -33,18 +31,13 @@ pub struct CodeGenerator { impl CodeGenerator { pub fn new(ir: &AirIR, extension_degree: u8) -> Result { - // maps indexes in Node vector in AlgebraicGraph and in `expressions` JSON array - let mut expressions_map = BTreeMap::new(); - let num_polys = set_num_polys(ir, extension_degree); let num_variables = set_num_variables(ir); let constants = set_constants(ir); - let mut expressions_handler = ExpressionsHandler::new(ir, &constants, &mut expressions_map); - - let mut expressions = expressions_handler.get_expressions()?; - // vector of `expressions` indexes - let outputs = expressions_handler.get_outputs(&mut expressions)?; + let mut gce_builder = GceBuilder::new(); + gce_builder.build(ir, &constants)?; + let (expressions, outputs) = gce_builder.into_gce()?; Ok(CodeGenerator { num_polys, @@ -173,13 +166,18 @@ fn set_constants(ir: &AirIR) -> Vec { constants } -/// Stroes node type required in [NodeReference] struct +/// Stores the node type required by the [NodeReference] struct. #[derive(Debug, Clone)] pub enum NodeType { + // Refers to the value in the trace column at the specified `index` in the current row. Pol, + // Refers to the value in the trace column at the specified `index` in the next row. PolNext, + // Refers to a public input or a random value at the specified `index`. Var, + // Refers to a constant at the specified `index`. Const, + // Refers to a previously defined expression at the specified index. Expr, } @@ -212,7 +210,8 @@ impl Display for ExpressionOperation { } } -/// Stores data used in JSON generation +/// Stores the reference to the node using the type of the node and index in related array of +/// nodes. #[derive(Debug, Clone)] pub struct NodeReference { pub node_type: NodeType, @@ -229,7 +228,8 @@ impl Display for NodeReference { } } -/// Stores data used in JSON generation +/// Stores the expression node using the expression operation and references to the left and rigth +/// nodes. #[derive(Clone, Debug)] pub struct ExpressionJson { pub op: ExpressionOperation,