diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index 54181921426..e205750a176 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -1,5 +1,5 @@ use acir::{ - brillig::{ForeignCallParam, RegisterIndex, Value}, + brillig::{ForeignCallParam, ForeignCallResult, RegisterIndex, Value}, circuit::{ brillig::{Brillig, BrilligInputs, BrilligOutputs}, OpcodeLocation, @@ -21,16 +21,17 @@ pub(super) enum BrilligSolverStatus { } pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> { - witness: &'b mut WitnessMap, - brillig: &'b Brillig, - acir_index: usize, vm: VM<'b, B>, + acir_index: usize, } impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { - pub(super) fn build_or_skip( - initial_witness: &'b mut WitnessMap, - brillig: &'b Brillig, + /// Constructs a solver for a Brillig block given the bytecode and initial + /// witness. If the block should be skipped entirely because its predicate + /// evaluates to false, zero out the block outputs and return Ok(None). + pub(super) fn build_or_skip<'w>( + initial_witness: &'w mut WitnessMap, + brillig: &'w Brillig, bb_solver: &'b B, acir_index: usize, ) -> Result, OpcodeResolutionError> { @@ -39,18 +40,11 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { return Ok(None); } - let vm = Self::setup_vm(initial_witness, brillig, bb_solver)?; - Ok(Some( - Self { - witness: initial_witness, - brillig, - acir_index, - vm, - } - )) + let vm = Self::build_vm(initial_witness, brillig, bb_solver)?; + Ok(Some(Self { vm, acir_index })) } - fn should_skip(witness: &mut WitnessMap, brillig: &Brillig) -> Result { + fn should_skip(witness: &WitnessMap, brillig: &Brillig) -> Result { // If the predicate is `None`, then we simply return the value 1 // If the predicate is `Some` but we cannot find a value, then we return stalled let pred_value = match &brillig.predicate { @@ -82,8 +76,8 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { Ok(()) } - fn setup_vm( - witness: &mut WitnessMap, + fn build_vm( + witness: &WitnessMap, brillig: &Brillig, bb_solver: &'b B, ) -> Result, OpcodeResolutionError> { @@ -137,24 +131,21 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { } pub(super) fn solve(&mut self) -> Result { - // Run the Brillig VM on these inputs, bytecode, etc! - while matches!(self.vm.process_opcode(), VMStatus::InProgress) {} - - self.finish_execution() + let status = self.vm.process_opcodes(); + self.handle_vm_status(status) } - pub(super) fn finish_execution(&mut self) -> Result { - // Check the status of the Brillig VM. + fn handle_vm_status( + &self, + vm_status: VMStatus, + ) -> Result { + // Check the status of the Brillig VM and return a resolution. // It may be finished, in-progress, failed, or may be waiting for results of a foreign call. // Return the "resolution" to the caller who may choose to make subsequent calls // (when it gets foreign call results for example). - let vm_status = self.vm.get_status(); match vm_status { - VMStatus::Finished => { - self.write_brillig_outputs()?; - Ok(BrilligSolverStatus::Finished) - } - VMStatus::InProgress => unreachable!("Brillig VM has not completed execution"), + VMStatus::Finished => Ok(BrilligSolverStatus::Finished), + VMStatus::InProgress => Ok(BrilligSolverStatus::InProgress), VMStatus::Failure { message, call_stack } => { Err(OpcodeResolutionError::BrilligFunctionFailed { message, @@ -173,25 +164,52 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { } } - fn write_brillig_outputs(&mut self) -> Result<(), OpcodeResolutionError> { + pub(super) fn finalize( + self, + witness: &mut WitnessMap, + brillig: &Brillig, + ) -> Result<(), OpcodeResolutionError> { + // Finish the Brillig execution by writing the outputs to the witness map + let vm_status = self.vm.get_status(); + match vm_status { + VMStatus::Finished => { + self.write_brillig_outputs(witness, brillig)?; + Ok(()) + } + _ => panic!("Brillig VM has not completed execution"), + } + } + + fn write_brillig_outputs( + &self, + witness_map: &mut WitnessMap, + brillig: &Brillig, + ) -> Result<(), OpcodeResolutionError> { // Write VM execution results into the witness map - for (i, output) in self.brillig.outputs.iter().enumerate() { + for (i, output) in brillig.outputs.iter().enumerate() { let register_value = self.vm.get_registers().get(RegisterIndex::from(i)); match output { BrilligOutputs::Simple(witness) => { - insert_value(witness, register_value.to_field(), self.witness)?; + insert_value(witness, register_value.to_field(), witness_map)?; } BrilligOutputs::Array(witness_arr) => { // Treat the register value as a pointer to memory for (i, witness) in witness_arr.iter().enumerate() { let value = &self.vm.get_memory()[register_value.to_usize() + i]; - insert_value(witness, value.to_field(), self.witness)?; + insert_value(witness, value.to_field(), witness_map)?; } } } } Ok(()) } + + pub(super) fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { + match self.vm.get_status() { + VMStatus::ForeignCallWait { .. } => self.vm.resolve_foreign_call(foreign_call_result), + _ => unreachable!("Brillig VM is not waiting for a foreign call"), + } + } } /// Encapsulates a request from a Brillig VM process that encounters a [foreign call opcode][acir::brillig_vm::Opcode::ForeignCall] diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index 7fc94433da8..532e7fbd0e0 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -11,7 +11,9 @@ use acir::{ use acvm_blackbox_solver::BlackBoxResolutionError; use self::{ - arithmetic::ArithmeticSolver, brillig::{BrilligSolver, BrilligSolverStatus}, directives::solve_directives, + arithmetic::ArithmeticSolver, + brillig::{BrilligSolver, BrilligSolverStatus}, + directives::solve_directives, memory_op::MemoryOpSolver, }; use crate::{BlackBoxFunctionSolver, Language}; @@ -140,6 +142,8 @@ pub struct ACVM<'backend, B: BlackBoxFunctionSolver> { instruction_pointer: usize, witness_map: WitnessMap, + + brillig_solver: Option>, } impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { @@ -152,6 +156,7 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { opcodes, instruction_pointer: 0, witness_map: initial_witness, + brillig_solver: None, } } @@ -216,12 +221,8 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { panic!("ACVM is not expecting a foreign call response as no call was made"); } - // We want to inject the foreign call result into the brillig opcode which initiated the call. - let opcode = &mut self.opcodes[self.instruction_pointer]; - let Opcode::Brillig(brillig) = opcode else { - unreachable!("ACVM can only enter `RequiresForeignCall` state on a Brillig opcode"); - }; - brillig.foreign_call_results.push(foreign_call_result); + let brillig_solver = self.brillig_solver.as_mut().expect("No active Brillig solver"); + brillig_solver.resolve_pending_foreign_call(foreign_call_result); // Now that the foreign call has been resolved then we can resume execution. self.status(ACVMStatus::InProgress); @@ -258,22 +259,36 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { solver.solve_memory_op(op, &mut self.witness_map, predicate) } Opcode::Brillig(brillig) => { - let result = BrilligSolver::build_or_skip( - &mut self.witness_map, - brillig, - self.backend, - self.instruction_pointer, - ); - match result { - Ok(Some(mut solver)) => { - match solver.solve() { - Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) => - return self.wait_for_foreign_call(foreign_call), - Ok(BrilligSolverStatus::InProgress) => - unreachable!("Brillig solver still in progress"), - res => res.map(|_| ()), + let witness = &mut self.witness_map; + // get the active Brillig solver, or try to build one if necessary + // (Brillig execution maybe bypassed by constraints) + let maybe_solver = match self.brillig_solver.as_mut() { + Some(solver) => Ok(Some(solver)), + None => BrilligSolver::build_or_skip( + witness, + brillig, + self.backend, + self.instruction_pointer, + ) + .and_then(|optional_solver| { + Ok(optional_solver + .and_then(|solver| Some(self.brillig_solver.insert(solver)))) + }), + }; + match maybe_solver { + Ok(Some(solver)) => match solver.solve() { + Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) => { + return self.wait_for_foreign_call(foreign_call); } - } + Ok(BrilligSolverStatus::InProgress) => { + unreachable!("Brillig solver still in progress") + } + Ok(BrilligSolverStatus::Finished) => { + // clear active Brillig solver and write execution outputs + self.brillig_solver.take().unwrap().finalize(witness, brillig) + } + res => res.map(|_| ()), + }, res => res.map(|_| ()), } }