diff --git a/executor/src/witgen/jit/block_machine_processor.rs b/executor/src/witgen/jit/block_machine_processor.rs index 4002b1cdc3..fc41ff6203 100644 --- a/executor/src/witgen/jit/block_machine_processor.rs +++ b/executor/src/witgen/jit/block_machine_processor.rs @@ -6,7 +6,10 @@ use powdr_ast::analyzed::{ContainsNextRef, PolyID, PolynomialType}; use powdr_number::FieldElement; use crate::witgen::{ - jit::{processor::Processor, prover_function_heuristics::decode_prover_functions}, + jit::{ + processor::Processor, prover_function_heuristics::decode_prover_functions, + witgen_inference::Assignment, + }, machines::MachineParts, FixedData, }; @@ -61,25 +64,38 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { .enumerate() .filter_map(|(i, is_input)| is_input.then_some(Variable::Param(i))) .collect::>(); - let mut witgen = WitgenInference::new(self.fixed_data, self, known_variables, []); + let witgen = WitgenInference::new(self.fixed_data, self, known_variables, []); let prover_functions = decode_prover_functions(&self.machine_parts, self.fixed_data)?; // In the latch row, set the RHS selector to 1. + let mut assignments = vec![]; let selector = &connection.right.selector; - witgen.assign_constant(selector, self.latch_row as i32, T::one()); + assignments.push(Assignment::assign_constant( + selector, + self.latch_row as i32, + T::one(), + )); // Set all other selectors to 0 in the latch row. for other_connection in self.machine_parts.connections.values() { let other_selector = &other_connection.right.selector; if other_selector != selector { - witgen.assign_constant(other_selector, self.latch_row as i32, T::zero()); + assignments.push(Assignment::assign_constant( + other_selector, + self.latch_row as i32, + T::zero(), + )); } } // For each argument, connect the expression on the RHS with the formal parameter. for (index, expr) in connection.right.expressions.iter().enumerate() { - witgen.assign_variable(expr, self.latch_row as i32, Variable::Param(index)); + assignments.push(Assignment::assign_variable( + expr, + self.latch_row as i32, + Variable::Param(index), + )); } let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions(); @@ -124,6 +140,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { self.fixed_data, self, identities, + assignments, requested_known, BLOCK_MACHINE_MAX_BRANCH_DEPTH, ) @@ -313,10 +330,10 @@ params[2] = Add::c[0];" assert_eq!(c_rc, &RangeConstraint::from_mask(0xffffffffu64)); assert_eq!( format_code(&result.code), - "main_binary::sel[0][3] = 1; -main_binary::operation_id[3] = params[0]; + "main_binary::operation_id[3] = params[0]; main_binary::A[3] = params[1]; main_binary::B[3] = params[2]; +main_binary::sel[0][3] = 1; main_binary::operation_id[2] = main_binary::operation_id[3]; main_binary::operation_id[1] = main_binary::operation_id[2]; main_binary::operation_id[0] = main_binary::operation_id[1]; diff --git a/executor/src/witgen/jit/identity_queue.rs b/executor/src/witgen/jit/identity_queue.rs index 294019eae3..621688ef9d 100644 --- a/executor/src/witgen/jit/identity_queue.rs +++ b/executor/src/witgen/jit/identity_queue.rs @@ -10,7 +10,9 @@ use powdr_ast::{ }; use powdr_number::FieldElement; -use crate::witgen::{data_structures::identity::Identity, FixedData}; +use crate::witgen::{ + data_structures::identity::Identity, jit::variable::MachineCallVariable, FixedData, +}; use super::{ variable::Variable, @@ -128,13 +130,25 @@ fn compute_occurrences_map<'b, 'a: 'b, T: FieldElement>( .flat_map(|item| { let variables = match item { QueueItem::Identity(id, row) => { - references_in_identity(id, fixed_data, &mut intermediate_cache) - .into_iter() + let mut variables = references_per_identity[&id.id()] + .iter() .map(|r| { let name = fixed_data.column_name(&r.poly_id).to_string(); Variable::from_reference(&r.with_name(name), *row) }) - .collect_vec() + .collect_vec(); + if let Identity::BusSend(bus_send) = id { + variables.extend((0..bus_send.selected_payload.expressions.len()).map( + |index| { + Variable::MachineCallParam(MachineCallVariable { + identity_id: id.id(), + row_offset: *row, + index, + }) + }, + )); + }; + variables } QueueItem::Assignment(a) => { variables_in_assignment(a, fixed_data, &mut intermediate_cache) @@ -152,9 +166,20 @@ fn references_in_identity( intermediate_cache: &mut HashMap>, ) -> Vec { let mut result = BTreeSet::new(); - for e in identity.children() { - result.extend(references_in_expression(e, fixed_data, intermediate_cache)); + + match identity { + Identity::BusSend(bus_send) => result.extend(references_in_expression( + &bus_send.selected_payload.selector, + fixed_data, + intermediate_cache, + )), + _ => { + for e in identity.children() { + result.extend(references_in_expression(e, fixed_data, intermediate_cache)); + } + } } + result.into_iter().collect() } diff --git a/executor/src/witgen/jit/processor.rs b/executor/src/witgen/jit/processor.rs index 1b0d5b3622..0ac4724bec 100644 --- a/executor/src/witgen/jit/processor.rs +++ b/executor/src/witgen/jit/processor.rs @@ -21,7 +21,9 @@ use super::{ identity_queue::{IdentityQueue, QueueItem}, prover_function_heuristics::ProverFunction, variable::{Cell, MachineCallVariable, Variable}, - witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference}, + witgen_inference::{ + Assignment, BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference, + }, }; /// A generic processor for generating JIT code. @@ -31,6 +33,8 @@ pub struct Processor<'a, T: FieldElement, FixedEval> { fixed_evaluator: FixedEval, /// List of identities and row offsets to process them on. identities: Vec<(&'a Identity, i32)>, + /// List of assignments provided from outside. + initial_assignments: Vec>, /// The prover functions, i.e. helpers to compute certain values that /// we cannot easily determine. prover_functions: Vec<(ProverFunction<'a, T>, i32)>, @@ -60,6 +64,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> Processor<'a, T, FixedEv fixed_data: &'a FixedData<'a, T>, fixed_evaluator: FixedEval, identities: impl IntoIterator, i32)>, + assignments: Vec>, requested_known_vars: impl IntoIterator, max_branch_depth: usize, ) -> Self { @@ -68,6 +73,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> Processor<'a, T, FixedEv fixed_data, fixed_evaluator, identities, + initial_assignments: assignments, prover_functions: vec![], block_size: 1, check_block_shape: false, @@ -111,23 +117,37 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> Processor<'a, T, FixedEv pub fn generate_code( self, can_process: impl CanProcessCall, - mut witgen: WitgenInference<'a, T, FixedEval>, + witgen: WitgenInference<'a, T, FixedEval>, ) -> Result, Error<'a, T, FixedEval>> { - // Create variables for bus send arguments. - for (id, row_offset) in &self.identities { - if let Identity::BusSend(bus_send) = id { - for (index, arg) in bus_send.selected_payload.expressions.iter().enumerate() { - let var = Variable::MachineCallParam(MachineCallVariable { - identity_id: bus_send.identity_id, - row_offset: *row_offset, - index, - }); - witgen.assign_variable(arg, *row_offset, var.clone()); - } - } - } + // Create variable assignments for bus send arguments. + let mut assignments = self.initial_assignments.clone(); + assignments.extend( + self.identities + .iter() + .filter_map(|(id, row_offset)| { + if let Identity::BusSend(bus_send) = id { + Some(( + bus_send.identity_id, + &bus_send.selected_payload.expressions, + *row_offset, + )) + } else { + None + } + }) + .flat_map(|(identity_id, arguments, row_offset)| { + arguments.iter().enumerate().map(move |(index, arg)| { + let var = Variable::MachineCallParam(MachineCallVariable { + identity_id, + row_offset, + index, + }); + Assignment::assign_variable(arg, row_offset, var) + }) + }), + ); let branch_depth = 0; - let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities, &[]); + let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities, &assignments); self.generate_code_for_branch(can_process, witgen, identity_queue, branch_depth) } @@ -296,11 +316,11 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> Processor<'a, T, FixedEv identity_queue: &mut IdentityQueue<'a, T>, ) -> Result<(), affine_symbolic_expression::Error> { loop { - let identity = identity_queue.next(); - let updated_vars = match identity { + let item = identity_queue.next(); + let updated_vars = match &item { Some(QueueItem::Identity(identity, row_offset)) => match identity { Identity::Polynomial(PolynomialIdentity { id, expression, .. }) => { - witgen.process_polynomial_identity(*id, expression, row_offset) + witgen.process_polynomial_identity(*id, expression, *row_offset) } Identity::BusSend(BusSend { bus_id: _, @@ -311,23 +331,21 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> Processor<'a, T, FixedEv *identity_id, &selected_payload.selector, selected_payload.expressions.len(), - row_offset, + *row_offset, ), Identity::Connect(..) => Ok(vec![]), }, + Some(QueueItem::Assignment(assignment)) => witgen.process_assignment(assignment), // TODO Also add prover functions to the queue (activated by their variables) // and sort them so that they are always last. - Some(QueueItem::Assignment(_assignment)) => { - todo!() - } None => self.process_prover_functions(witgen), }?; - if updated_vars.is_empty() && identity.is_none() { + if updated_vars.is_empty() && item.is_none() { // No identities to process and prover functions did not make any progress, // we are done. return Ok(()); } - identity_queue.variables_updated(updated_vars, identity); + identity_queue.variables_updated(updated_vars, item); } } diff --git a/executor/src/witgen/jit/single_step_processor.rs b/executor/src/witgen/jit/single_step_processor.rs index aae13fb543..a5cac75dd8 100644 --- a/executor/src/witgen/jit/single_step_processor.rs +++ b/executor/src/witgen/jit/single_step_processor.rs @@ -81,6 +81,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> { self.fixed_data, self, identities, + vec![], requested_known, SINGLE_STEP_MACHINE_MAX_BRANCH_DEPTH, ) @@ -237,9 +238,9 @@ namespace M(256); assert_eq!( format_code(&code), "\ -call_var(1, 0, 0) = VM::pc[0]; call_var(1, 0, 1) = VM::instr_add[0]; call_var(1, 0, 2) = VM::instr_mul[0]; +call_var(1, 0, 0) = VM::pc[0]; VM::pc[1] = (VM::pc[0] + 1); call_var(1, 1, 0) = VM::pc[1]; VM::B[1] = VM::B[0]; @@ -280,9 +281,9 @@ if (VM::instr_add[0] == 1) { assert_eq!( format_code(&code), "\ -call_var(2, 0, 0) = VM::pc[0]; call_var(2, 0, 1) = VM::instr_add[0]; call_var(2, 0, 2) = VM::instr_mul[0]; +call_var(2, 0, 0) = VM::pc[0]; VM::pc[1] = VM::pc[0]; call_var(2, 1, 0) = VM::pc[1]; VM::instr_add[1] = 0; diff --git a/executor/src/witgen/jit/witgen_inference.rs b/executor/src/witgen/jit/witgen_inference.rs index ae01e715c6..fdc1770318 100644 --- a/executor/src/witgen/jit/witgen_inference.rs +++ b/executor/src/witgen/jit/witgen_inference.rs @@ -1,5 +1,5 @@ use std::{ - collections::{BTreeSet, HashMap, HashSet}, + collections::{HashMap, HashSet}, fmt::{Display, Formatter}, }; @@ -42,8 +42,6 @@ pub struct WitgenInference<'a, T: FieldElement, FixedEval> { /// This mainly avoids generating multiple submachine calls for the same /// connection on the same row. complete_identities: HashSet<(u64, i32)>, - /// Internal equality constraints that are not identities from the constraint set. - assignments: BTreeSet>, code: Vec>, } @@ -88,7 +86,6 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F derived_range_constraints: Default::default(), known_variables: known_variables.into_iter().collect(), complete_identities: complete_identities.into_iter().collect(), - assignments: Default::default(), code: Default::default(), } } @@ -200,6 +197,15 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F self.ingest_effects(result, Some((lookup_id, row_offset))) } + pub fn process_assignment( + &mut self, + assignment: &Assignment<'a, T>, + ) -> Result, Error> { + let result = + self.process_equality_on_row(assignment.lhs, assignment.row_offset, &assignment.rhs)?; + self.ingest_effects(result, None) + } + /// Process a prover function on a row, i.e. determine if we can execute it and if it will /// help us to compute the value of previously unknown variables. /// Returns the list of updated variables. @@ -253,35 +259,6 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F ) } - /// Process the constraint that the expression evaluated at the given offset equals the given value. - /// This does not have to be solvable right away, but is always processed as soon as we have progress. - /// Note that all variables in the expression can be unknown and their status can also change over time. - pub fn assign_constant(&mut self, expression: &'a Expression, row_offset: i32, value: T) { - self.assignments.insert(Assignment { - lhs: expression, - row_offset, - rhs: VariableOrValue::Value(value), - }); - self.process_assignments().unwrap(); - } - - /// Process the constraint that the expression evaluated at the given offset equals the given formal variable. - /// This does not have to be solvable right away, but is always processed as soon as we have progress. - /// Note that all variables in the expression can be unknown and their status can also change over time. - pub fn assign_variable( - &mut self, - expression: &'a Expression, - row_offset: i32, - variable: Variable, - ) { - self.assignments.insert(Assignment { - lhs: expression, - row_offset, - rhs: VariableOrValue::Variable(variable), - }); - self.process_assignments().unwrap(); - } - /// Processes an equality constraint. /// If this returns an error, it means we have conflicting constraints. fn process_equality_on_row( @@ -332,7 +309,6 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F argument_count: usize, row_offset: i32, ) -> ProcessResult { - self.process_assignments().unwrap(); // We need to know the selector. let Some(selector) = self .evaluate(selector, row_offset) @@ -386,31 +362,6 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F } } - fn process_assignments(&mut self) -> Result, Error> { - let mut updated_variables = vec![]; - loop { - let mut progress = false; - // We need to take them out because ingest_effects needs a &mut self. - let assignments = std::mem::take(&mut self.assignments); - for assignment in &assignments { - let r = self.process_equality_on_row( - assignment.lhs, - assignment.row_offset, - &assignment.rhs, - )?; - let updated_vars = self.ingest_effects(r, None)?; - progress |= !updated_vars.is_empty(); - updated_variables.extend(updated_vars); - } - assert!(self.assignments.is_empty()); - self.assignments = assignments; - if !progress { - break; - } - } - Ok(updated_variables) - } - /// Analyze the effects and update the internal state. /// If the effect is the result of a machine call, `identity_id` must be given /// to avoid two calls to the same sub-machine on the same row. @@ -493,10 +444,6 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F self.complete_identities.insert(identity_id); } } - if !updated_variables.is_empty() { - // TODO we could have an occurrence map for the assignments as well. - updated_variables.extend(self.process_assignments()?); - } Ok(updated_variables) } @@ -677,6 +624,8 @@ fn is_known_zero(x: &Option { pub rhs: VariableOrValue, } +impl<'a, T: FieldElement> Assignment<'a, T> { + pub fn assign_constant(lhs: &'a Expression, row_offset: i32, rhs: T) -> Self { + Self { + lhs, + row_offset, + rhs: VariableOrValue::Value(rhs), + } + } + + pub fn assign_variable(lhs: &'a Expression, row_offset: i32, rhs: Variable) -> Self { + Self { + lhs, + row_offset, + rhs: VariableOrValue::Variable(rhs), + } + } +} + #[derive(Clone, derive_more::Display, Ord, PartialOrd, Eq, PartialEq, Debug)] pub enum VariableOrValue { Variable(V), Value(T), } +impl Display for Assignment<'_, T> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{} = {} [row {}]", + self.lhs, + match &self.rhs { + VariableOrValue::Variable(v) => v.to_string(), + VariableOrValue::Value(v) => v.to_string(), + }, + self.row_offset + ) + } +} + pub trait FixedEvaluator: Clone { /// Evaluate a fixed column cell and returns its value if it is /// compile-time constant, otherwise return None. @@ -791,21 +773,6 @@ mod test { let ref_eval = FixedEvaluatorForFixedData(&fixed_data); let mut witgen = WitgenInference::new(&fixed_data, ref_eval, known_cells, []); let mut counter = 0; - // Create variables for bus send arguments. - for row in rows { - for id in &fixed_data.identities { - if let Identity::BusSend(bus_send) = id { - for (index, arg) in bus_send.selected_payload.expressions.iter().enumerate() { - let var = Variable::MachineCallParam(MachineCallVariable { - identity_id: bus_send.identity_id, - row_offset: *row, - index, - }); - witgen.assign_variable(arg, *row, var.clone()); - } - } - } - } loop { let mut progress = false; @@ -820,15 +787,37 @@ mod test { bus_id: _, identity_id, selected_payload, - }) => witgen - .process_call( - &mutable_state, - *identity_id, - &selected_payload.selector, - selected_payload.expressions.len(), - *row, - ) - .unwrap(), + }) => { + let mut updated_vars = vec![]; + for (index, arg) in selected_payload.expressions.iter().enumerate() { + let var = Variable::MachineCallParam(MachineCallVariable { + identity_id: *identity_id, + row_offset: *row, + index, + }); + updated_vars.extend( + witgen + .process_assignment(&Assignment { + lhs: arg, + row_offset: *row, + rhs: VariableOrValue::Variable(var), + }) + .unwrap(), + ); + } + updated_vars.extend( + witgen + .process_call( + &mutable_state, + *identity_id, + &selected_payload.selector, + selected_payload.expressions.len(), + *row, + ) + .unwrap(), + ); + updated_vars + } Identity::Connect(..) => vec![], }; progress |= !updated_vars.is_empty(); @@ -930,38 +919,38 @@ namespace Xor(256 * 256); Xor::A_byte[6] = ((Xor::A[7] & 0xff000000) // 16777216); Xor::A[6] = (Xor::A[7] & 0xffffff); assert (Xor::A[7] & 0xffffffff00000000) == 0; -call_var(0, 6, 0) = Xor::A_byte[6]; Xor::C_byte[6] = ((Xor::C[7] & 0xff000000) // 16777216); Xor::C[6] = (Xor::C[7] & 0xffffff); assert (Xor::C[7] & 0xffffffff00000000) == 0; -call_var(0, 6, 2) = Xor::C_byte[6]; Xor::A_byte[5] = ((Xor::A[6] & 0xff0000) // 65536); Xor::A[5] = (Xor::A[6] & 0xffff); assert (Xor::A[6] & 0xffffffffff000000) == 0; -call_var(0, 5, 0) = Xor::A_byte[5]; Xor::C_byte[5] = ((Xor::C[6] & 0xff0000) // 65536); Xor::C[5] = (Xor::C[6] & 0xffff); assert (Xor::C[6] & 0xffffffffff000000) == 0; -call_var(0, 5, 2) = Xor::C_byte[5]; +call_var(0, 6, 0) = Xor::A_byte[6]; +call_var(0, 6, 2) = Xor::C_byte[6]; machine_call(0, [Known(call_var(0, 6, 0)), Unknown(call_var(0, 6, 1)), Known(call_var(0, 6, 2))]); -Xor::B_byte[6] = call_var(0, 6, 1); Xor::A_byte[4] = ((Xor::A[5] & 0xff00) // 256); Xor::A[4] = (Xor::A[5] & 0xff); assert (Xor::A[5] & 0xffffffffffff0000) == 0; -call_var(0, 4, 0) = Xor::A_byte[4]; Xor::C_byte[4] = ((Xor::C[5] & 0xff00) // 256); Xor::C[4] = (Xor::C[5] & 0xff); assert (Xor::C[5] & 0xffffffffffff0000) == 0; -call_var(0, 4, 2) = Xor::C_byte[4]; +call_var(0, 5, 0) = Xor::A_byte[5]; +call_var(0, 5, 2) = Xor::C_byte[5]; machine_call(0, [Known(call_var(0, 5, 0)), Unknown(call_var(0, 5, 1)), Known(call_var(0, 5, 2))]); -Xor::B_byte[5] = call_var(0, 5, 1); +Xor::B_byte[6] = call_var(0, 6, 1); Xor::A_byte[3] = Xor::A[4]; -call_var(0, 3, 0) = Xor::A_byte[3]; Xor::C_byte[3] = Xor::C[4]; -call_var(0, 3, 2) = Xor::C_byte[3]; +call_var(0, 4, 0) = Xor::A_byte[4]; +call_var(0, 4, 2) = Xor::C_byte[4]; machine_call(0, [Known(call_var(0, 4, 0)), Unknown(call_var(0, 4, 1)), Known(call_var(0, 4, 2))]); -Xor::B_byte[4] = call_var(0, 4, 1); +Xor::B_byte[5] = call_var(0, 5, 1); +call_var(0, 3, 0) = Xor::A_byte[3]; +call_var(0, 3, 2) = Xor::C_byte[3]; machine_call(0, [Known(call_var(0, 3, 0)), Unknown(call_var(0, 3, 1)), Known(call_var(0, 3, 2))]); +Xor::B_byte[4] = call_var(0, 4, 1); Xor::B_byte[3] = call_var(0, 3, 1); Xor::B[4] = Xor::B_byte[3]; Xor::B[5] = (Xor::B[4] + (Xor::B_byte[4] * 256));