Skip to content

Commit

Permalink
Move assignments to processor (#2454)
Browse files Browse the repository at this point in the history
This moves the handling of assignments from witgen_inference to the
processor and also uses the queue to schedule assignments.
  • Loading branch information
chriseth authored Feb 7, 2025
1 parent f198341 commit d7fdda3
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 138 deletions.
31 changes: 24 additions & 7 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use powdr_ast::analyzed::{ContainsNextRef, PolyID, PolynomialType};
use powdr_number::FieldElement;

use crate::witgen::{
jit::{processor::Processor, prover_function_heuristics::decode_prover_functions},
jit::{
processor::Processor, prover_function_heuristics::decode_prover_functions,
witgen_inference::Assignment,
},
machines::MachineParts,
FixedData,
};
Expand Down Expand Up @@ -61,25 +64,38 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
.enumerate()
.filter_map(|(i, is_input)| is_input.then_some(Variable::Param(i)))
.collect::<HashSet<_>>();
let mut witgen = WitgenInference::new(self.fixed_data, self, known_variables, []);
let witgen = WitgenInference::new(self.fixed_data, self, known_variables, []);

let prover_functions = decode_prover_functions(&self.machine_parts, self.fixed_data)?;

// In the latch row, set the RHS selector to 1.
let mut assignments = vec![];
let selector = &connection.right.selector;
witgen.assign_constant(selector, self.latch_row as i32, T::one());
assignments.push(Assignment::assign_constant(
selector,
self.latch_row as i32,
T::one(),
));

// Set all other selectors to 0 in the latch row.
for other_connection in self.machine_parts.connections.values() {
let other_selector = &other_connection.right.selector;
if other_selector != selector {
witgen.assign_constant(other_selector, self.latch_row as i32, T::zero());
assignments.push(Assignment::assign_constant(
other_selector,
self.latch_row as i32,
T::zero(),
));
}
}

// For each argument, connect the expression on the RHS with the formal parameter.
for (index, expr) in connection.right.expressions.iter().enumerate() {
witgen.assign_variable(expr, self.latch_row as i32, Variable::Param(index));
assignments.push(Assignment::assign_variable(
expr,
self.latch_row as i32,
Variable::Param(index),
));
}

let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions();
Expand Down Expand Up @@ -124,6 +140,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
self.fixed_data,
self,
identities,
assignments,
requested_known,
BLOCK_MACHINE_MAX_BRANCH_DEPTH,
)
Expand Down Expand Up @@ -313,10 +330,10 @@ params[2] = Add::c[0];"
assert_eq!(c_rc, &RangeConstraint::from_mask(0xffffffffu64));
assert_eq!(
format_code(&result.code),
"main_binary::sel[0][3] = 1;
main_binary::operation_id[3] = params[0];
"main_binary::operation_id[3] = params[0];
main_binary::A[3] = params[1];
main_binary::B[3] = params[2];
main_binary::sel[0][3] = 1;
main_binary::operation_id[2] = main_binary::operation_id[3];
main_binary::operation_id[1] = main_binary::operation_id[2];
main_binary::operation_id[0] = main_binary::operation_id[1];
Expand Down
37 changes: 31 additions & 6 deletions executor/src/witgen/jit/identity_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use powdr_ast::{
};
use powdr_number::FieldElement;

use crate::witgen::{data_structures::identity::Identity, FixedData};
use crate::witgen::{
data_structures::identity::Identity, jit::variable::MachineCallVariable, FixedData,
};

use super::{
variable::Variable,
Expand Down Expand Up @@ -128,13 +130,25 @@ fn compute_occurrences_map<'b, 'a: 'b, T: FieldElement>(
.flat_map(|item| {
let variables = match item {
QueueItem::Identity(id, row) => {
references_in_identity(id, fixed_data, &mut intermediate_cache)
.into_iter()
let mut variables = references_per_identity[&id.id()]
.iter()
.map(|r| {
let name = fixed_data.column_name(&r.poly_id).to_string();
Variable::from_reference(&r.with_name(name), *row)
})
.collect_vec()
.collect_vec();
if let Identity::BusSend(bus_send) = id {
variables.extend((0..bus_send.selected_payload.expressions.len()).map(
|index| {
Variable::MachineCallParam(MachineCallVariable {
identity_id: id.id(),
row_offset: *row,
index,
})
},
));
};
variables
}
QueueItem::Assignment(a) => {
variables_in_assignment(a, fixed_data, &mut intermediate_cache)
Expand All @@ -152,9 +166,20 @@ fn references_in_identity<T: FieldElement>(
intermediate_cache: &mut HashMap<AlgebraicReferenceThin, Vec<AlgebraicReferenceThin>>,
) -> Vec<AlgebraicReferenceThin> {
let mut result = BTreeSet::new();
for e in identity.children() {
result.extend(references_in_expression(e, fixed_data, intermediate_cache));

match identity {
Identity::BusSend(bus_send) => result.extend(references_in_expression(
&bus_send.selected_payload.selector,
fixed_data,
intermediate_cache,
)),
_ => {
for e in identity.children() {
result.extend(references_in_expression(e, fixed_data, intermediate_cache));
}
}
}

result.into_iter().collect()
}

Expand Down
68 changes: 43 additions & 25 deletions executor/src/witgen/jit/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use super::{
identity_queue::{IdentityQueue, QueueItem},
prover_function_heuristics::ProverFunction,
variable::{Cell, MachineCallVariable, Variable},
witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference},
witgen_inference::{
Assignment, BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference,
},
};

/// A generic processor for generating JIT code.
Expand All @@ -31,6 +33,8 @@ pub struct Processor<'a, T: FieldElement, FixedEval> {
fixed_evaluator: FixedEval,
/// List of identities and row offsets to process them on.
identities: Vec<(&'a Identity<T>, i32)>,
/// List of assignments provided from outside.
initial_assignments: Vec<Assignment<'a, T>>,
/// The prover functions, i.e. helpers to compute certain values that
/// we cannot easily determine.
prover_functions: Vec<(ProverFunction<'a, T>, i32)>,
Expand Down Expand Up @@ -60,6 +64,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
fixed_data: &'a FixedData<'a, T>,
fixed_evaluator: FixedEval,
identities: impl IntoIterator<Item = (&'a Identity<T>, i32)>,
assignments: Vec<Assignment<'a, T>>,
requested_known_vars: impl IntoIterator<Item = Variable>,
max_branch_depth: usize,
) -> Self {
Expand All @@ -68,6 +73,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
fixed_data,
fixed_evaluator,
identities,
initial_assignments: assignments,
prover_functions: vec![],
block_size: 1,
check_block_shape: false,
Expand Down Expand Up @@ -111,23 +117,37 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
pub fn generate_code(
self,
can_process: impl CanProcessCall<T>,
mut witgen: WitgenInference<'a, T, FixedEval>,
witgen: WitgenInference<'a, T, FixedEval>,
) -> Result<ProcessorResult<T>, Error<'a, T, FixedEval>> {
// Create variables for bus send arguments.
for (id, row_offset) in &self.identities {
if let Identity::BusSend(bus_send) = id {
for (index, arg) in bus_send.selected_payload.expressions.iter().enumerate() {
let var = Variable::MachineCallParam(MachineCallVariable {
identity_id: bus_send.identity_id,
row_offset: *row_offset,
index,
});
witgen.assign_variable(arg, *row_offset, var.clone());
}
}
}
// Create variable assignments for bus send arguments.
let mut assignments = self.initial_assignments.clone();
assignments.extend(
self.identities
.iter()
.filter_map(|(id, row_offset)| {
if let Identity::BusSend(bus_send) = id {
Some((
bus_send.identity_id,
&bus_send.selected_payload.expressions,
*row_offset,
))
} else {
None
}
})
.flat_map(|(identity_id, arguments, row_offset)| {
arguments.iter().enumerate().map(move |(index, arg)| {
let var = Variable::MachineCallParam(MachineCallVariable {
identity_id,
row_offset,
index,
});
Assignment::assign_variable(arg, row_offset, var)
})
}),
);
let branch_depth = 0;
let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities, &[]);
let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities, &assignments);
self.generate_code_for_branch(can_process, witgen, identity_queue, branch_depth)
}

Expand Down Expand Up @@ -296,11 +316,11 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
identity_queue: &mut IdentityQueue<'a, T>,
) -> Result<(), affine_symbolic_expression::Error> {
loop {
let identity = identity_queue.next();
let updated_vars = match identity {
let item = identity_queue.next();
let updated_vars = match &item {
Some(QueueItem::Identity(identity, row_offset)) => match identity {
Identity::Polynomial(PolynomialIdentity { id, expression, .. }) => {
witgen.process_polynomial_identity(*id, expression, row_offset)
witgen.process_polynomial_identity(*id, expression, *row_offset)
}
Identity::BusSend(BusSend {
bus_id: _,
Expand All @@ -311,23 +331,21 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
*identity_id,
&selected_payload.selector,
selected_payload.expressions.len(),
row_offset,
*row_offset,
),
Identity::Connect(..) => Ok(vec![]),
},
Some(QueueItem::Assignment(assignment)) => witgen.process_assignment(assignment),
// TODO Also add prover functions to the queue (activated by their variables)
// and sort them so that they are always last.
Some(QueueItem::Assignment(_assignment)) => {
todo!()
}
None => self.process_prover_functions(witgen),
}?;
if updated_vars.is_empty() && identity.is_none() {
if updated_vars.is_empty() && item.is_none() {
// No identities to process and prover functions did not make any progress,
// we are done.
return Ok(());
}
identity_queue.variables_updated(updated_vars, identity);
identity_queue.variables_updated(updated_vars, item);
}
}

Expand Down
5 changes: 3 additions & 2 deletions executor/src/witgen/jit/single_step_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> {
self.fixed_data,
self,
identities,
vec![],
requested_known,
SINGLE_STEP_MACHINE_MAX_BRANCH_DEPTH,
)
Expand Down Expand Up @@ -237,9 +238,9 @@ namespace M(256);
assert_eq!(
format_code(&code),
"\
call_var(1, 0, 0) = VM::pc[0];
call_var(1, 0, 1) = VM::instr_add[0];
call_var(1, 0, 2) = VM::instr_mul[0];
call_var(1, 0, 0) = VM::pc[0];
VM::pc[1] = (VM::pc[0] + 1);
call_var(1, 1, 0) = VM::pc[1];
VM::B[1] = VM::B[0];
Expand Down Expand Up @@ -280,9 +281,9 @@ if (VM::instr_add[0] == 1) {
assert_eq!(
format_code(&code),
"\
call_var(2, 0, 0) = VM::pc[0];
call_var(2, 0, 1) = VM::instr_add[0];
call_var(2, 0, 2) = VM::instr_mul[0];
call_var(2, 0, 0) = VM::pc[0];
VM::pc[1] = VM::pc[0];
call_var(2, 1, 0) = VM::pc[1];
VM::instr_add[1] = 0;
Expand Down
Loading

0 comments on commit d7fdda3

Please sign in to comment.