From d7e0633200099f1c36d9d7bd9b1e808b3d4ee7a0 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Fri, 27 Sep 2024 10:22:14 +0100 Subject: [PATCH 1/4] feat(scheduler): add task coarsening policies and coarse grained task execution capability The task scheduling is controlled by the FHEVM_DF_SCHEDULE environment variable with available policies: - MAX_PARALLELISM: preserves parallelism but aggregates dependent tasks to take advantage of spatial and temporal cache locality. - MAX_LOCALITY: aggregates all connected components in the DFG to improve cache locality. - FINE_GRAIN: disable task aggregation and executes each FHE operation based on DFG dependences. --- fhevm-engine/executor/src/dfg.rs | 64 +- fhevm-engine/executor/src/dfg/scheduler.rs | 333 ++++++- fhevm-engine/executor/src/dfg/types.rs | 5 +- .../executor/tests/scheduling_mapping.rs | 180 ++++ .../executor/tests/scheduling_patterns.rs | 840 ++++++++++++++++++ fhevm-engine/executor/tests/sync_compute.rs | 538 ----------- 6 files changed, 1361 insertions(+), 599 deletions(-) create mode 100644 fhevm-engine/executor/tests/scheduling_mapping.rs create mode 100644 fhevm-engine/executor/tests/scheduling_patterns.rs diff --git a/fhevm-engine/executor/src/dfg.rs b/fhevm-engine/executor/src/dfg.rs index a10de529..3fad997e 100644 --- a/fhevm-engine/executor/src/dfg.rs +++ b/fhevm-engine/executor/src/dfg.rs @@ -15,19 +15,29 @@ use tfhe::integer::U256; use daggy::{Dag, NodeIndex}; use std::collections::HashMap; -//TODO#[derive(Debug)] -pub struct Node<'a> { - computation: &'a SyncComputation, +pub struct OpNode { + opcode: i32, result: DFGTaskResult, result_handle: Handle, inputs: Vec, } -pub type Edge = u8; +pub type OpEdge = u8; -//TODO#[derive(Debug)] -#[derive(Default)] +impl std::fmt::Debug for OpNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OpNode") + .field("OP", &self.opcode) + .field( + "Result", + &format_args!("{0:?} (0x{0:X})", &self.result_handle[0]), + ) + .finish() + } +} + +#[derive(Default, Debug)] pub struct DFGraph<'a> { - pub graph: Dag, Edge>, + pub graph: Dag, produced_handles: HashMap<&'a Handle, NodeIndex>, } @@ -42,8 +52,8 @@ impl<'a> DFGraph<'a> { .first() .filter(|h| h.len() == HANDLE_LEN) .ok_or(SyncComputeError::BadResultHandles)?; - Ok(self.graph.add_node(Node { - computation, + Ok(self.graph.add_node(OpNode { + opcode: computation.operation, result: None, result_handle: rh.clone(), inputs, @@ -54,7 +64,7 @@ impl<'a> DFGraph<'a> { &mut self, source: NodeIndex, destination: NodeIndex, - consumer_input: Edge, + consumer_input: OpEdge, ) -> Result<(), SyncComputeError> { let _edge = self .graph @@ -78,7 +88,7 @@ impl<'a> DFGraph<'a> { if let Some(ct) = state.ciphertexts.get(h) { Ok(DFGTaskInput::Val(ct.expanded.clone())) } else { - Ok(DFGTaskInput::Handle(h.clone())) + Ok(DFGTaskInput::Dep(None)) } } Input::Scalar(s) if s.len() == SCALAR_LEN => { @@ -103,31 +113,23 @@ impl<'a> DFGraph<'a> { ); } } - // Traverse nodes and add dependences/edges as required - for index in 0..self.graph.node_count() { - let take_inputs = std::mem::take( - &mut self - .graph - .node_weight_mut(NodeIndex::new(index)) - .unwrap() - .inputs, - ); - for (idx, input) in take_inputs.iter().enumerate() { - match input { - DFGTaskInput::Handle(input) => { + // Traverse computations and add dependences/edges as required + for (index, computation) in req.computations.iter().enumerate() { + for (input_idx, input) in computation.inputs.iter().enumerate() { + if let Some(Input::Handle(input)) = &input.input { + if !state.ciphertexts.contains_key(input) { if let Some(producer_index) = self.produced_handles.get(input) { - self.add_dependence(*producer_index, NodeIndex::new(index), idx as u8)?; + let consumer_index = NodeIndex::new(index); + self.graph[consumer_index].inputs[input_idx] = + DFGTaskInput::Dep(Some((*producer_index).index())); + self.add_dependence(*producer_index, consumer_index, input_idx as u8)?; + } else { + return Err(SyncComputeError::UnsatisfiedDependence); } } - DFGTaskInput::Val(_) => {} - }; + } } - self.graph - .node_weight_mut(NodeIndex::new(index)) - .unwrap() - .inputs = take_inputs; } - Ok(()) } diff --git a/fhevm-engine/executor/src/dfg/scheduler.rs b/fhevm-engine/executor/src/dfg/scheduler.rs index d92e43bf..f0d49fea 100644 --- a/fhevm-engine/executor/src/dfg/scheduler.rs +++ b/fhevm-engine/executor/src/dfg/scheduler.rs @@ -1,6 +1,14 @@ -use crate::dfg::{types::DFGTaskInput, Edge, Node}; +use std::collections::HashMap; +use std::sync::atomic::AtomicUsize; + +use crate::dfg::types::*; +use crate::dfg::{OpEdge, OpNode}; use crate::server::{run_computation, InMemoryCiphertext, SyncComputeError}; use anyhow::Result; +use daggy::petgraph::csr::IndexType; +use daggy::petgraph::graph::node_index; +use daggy::petgraph::visit::{IntoEdgeReferences, IntoNeighbors, VisitMap, Visitable}; +use daggy::petgraph::Direction::Incoming; use fhevm_engine_common::types::SupportedFheCiphertexts; use daggy::{ @@ -12,47 +20,96 @@ use daggy::{ }; use tokio::task::JoinSet; -pub struct Scheduler<'a, 'b> { - graph: &'b mut Dag, Edge>, - edges: Dag<(), Edge>, - set: JoinSet>, +struct ExecNode { + df_nodes: Vec, + dependence_counter: AtomicUsize, +} + +pub enum PartitionStrategy { + MaxParallelism, + MaxLocality, } -impl<'a, 'b> Scheduler<'a, 'b> { - fn is_ready(node: &Node<'a>) -> bool { +impl std::fmt::Debug for ExecNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.df_nodes.is_empty() { + write!(f, "Vec [ ]") + } else { + let _ = write!(f, "Vec [ "); + for i in self.df_nodes.iter() { + let _ = write!(f, "{}, ", i.index()); + } + write!(f, "] - dependences: {:?}", self.dependence_counter) + } + } +} + +pub struct Scheduler<'a> { + graph: &'a mut Dag, + edges: Dag<(), OpEdge>, +} + +impl<'a> Scheduler<'a> { + fn is_ready(node: &OpNode) -> bool { let mut ready = true; for i in node.inputs.iter() { - if let DFGTaskInput::Handle(_) = i { + if let DFGTaskInput::Dep(_) = i { ready = false; } } ready } - pub fn new(graph: &'b mut Dag, Edge>) -> Self { - let mut set = JoinSet::new(); - for idx in 0..graph.node_count() { + fn is_ready_task(&self, node: &ExecNode) -> bool { + node.dependence_counter + .load(std::sync::atomic::Ordering::SeqCst) + == 0 + } + pub fn new(graph: &'a mut Dag) -> Self { + let edges = graph.map(|_, _| (), |_, edge| *edge); + Self { graph, edges } + } + + pub async fn schedule(&mut self) -> Result<(), SyncComputeError> { + let schedule_type = std::env::var("FHEVM_DF_SCHEDULE"); + match schedule_type { + Ok(val) if val == "MAX_PARALLELISM" => { + self.schedule_coarse_grain(PartitionStrategy::MaxParallelism) + .await + } + Ok(val) if val == "MAX_LOCALITY" => { + self.schedule_coarse_grain(PartitionStrategy::MaxLocality) + .await + } + Ok(val) if val == "LOOP" => panic!("Unimplemented LOOP scheduling strategy"), + Ok(val) if val == "FINE_GRAIN" => self.schedule_fine_grain().await, + Ok(unhandled) => panic!("Scheduling strategy {:?} does not exist", unhandled), + + _ => self.schedule_fine_grain().await, + } + } + + async fn schedule_fine_grain(&mut self) -> Result<(), SyncComputeError> { + let mut set: JoinSet> = + JoinSet::new(); + // Prime the scheduler with all nodes without dependences + for idx in 0..self.graph.node_count() { let index = NodeIndex::new(idx); - let node = graph.node_weight_mut(index).unwrap(); + let node = self.graph.node_weight_mut(index).unwrap(); if Self::is_ready(node) { - let opc = node.computation.operation; + let opcode = node.opcode; let inputs: Result, SyncComputeError> = node .inputs .iter() .map(|i| match i { DFGTaskInput::Val(i) => Ok(i.clone()), - DFGTaskInput::Handle(_) => Err(SyncComputeError::ComputationFailed), + _ => Err(SyncComputeError::ComputationFailed), }) .collect(); - set.spawn_blocking(move || run_computation(opc, inputs, idx)); + set.spawn_blocking(move || run_computation(opcode, inputs, idx)); } } - - let edges = graph.map(|_, _| (), |_, edge| *edge); - - Self { graph, edges, set } - } - pub async fn schedule(&mut self) -> Result<(), SyncComputeError> { - while let Some(result) = self.set.join_next().await { + // Get results from computations and update dependences of remaining computations + while let Some(result) = set.join_next().await { let output = result.map_err(|_| SyncComputeError::ComputationFailed)??; let index = output.0; let node_index = NodeIndex::new(index); @@ -63,21 +120,243 @@ impl<'a, 'b> Scheduler<'a, 'b> { child_node.inputs[*edge.weight() as usize] = DFGTaskInput::Val(output.1.expanded.clone()); if Self::is_ready(child_node) { - let opc = child_node.computation.operation; + let opcode = child_node.opcode; let inputs: Result, SyncComputeError> = child_node .inputs .iter() .map(|i| match i { DFGTaskInput::Val(i) => Ok(i.clone()), - DFGTaskInput::Handle(_) => Err(SyncComputeError::ComputationFailed), + _ => Err(SyncComputeError::ComputationFailed), }) .collect(); - self.set - .spawn_blocking(move || run_computation(opc, inputs, child_index.index())); + set.spawn_blocking(move || { + run_computation(opcode, inputs, child_index.index()) + }); } } - self.graph.node_weight_mut(node_index).unwrap().result = Some(output.1); + self.graph[node_index].result = Some(output.1); } Ok(()) } + + async fn schedule_coarse_grain( + &mut self, + strategy: PartitionStrategy, + ) -> Result<(), SyncComputeError> { + let mut set: JoinSet< + Result<(Vec<(usize, InMemoryCiphertext)>, NodeIndex), SyncComputeError>, + > = JoinSet::new(); + let mut execution_graph: Dag = Dag::default(); + match strategy { + PartitionStrategy::MaxLocality => { + let _ = partition_components(self.graph, &mut execution_graph); + } + PartitionStrategy::MaxParallelism => { + let _ = partition_preserving_parallelism(self.graph, &mut execution_graph); + } + } + let task_dependences = execution_graph.map(|_, _| (), |_, edge| *edge); + + // Prime the scheduler with all nodes without dependences + for idx in 0..execution_graph.node_count() { + let index = NodeIndex::new(idx); + let node = execution_graph.node_weight_mut(index).unwrap(); + if self.is_ready_task(node) { + let mut args = Vec::with_capacity(node.df_nodes.len()); + for nidx in node.df_nodes.iter() { + let n = self.graph.node_weight_mut(*nidx).unwrap(); + let opcode = n.opcode; + args.push((opcode, std::mem::take(&mut n.inputs), *nidx)); + } + set.spawn_blocking(move || execute_partition(args, index)); + } + } + // Get results from computations and update dependences of remaining computations + while let Some(result) = set.join_next().await { + let mut output = result.map_err(|_| SyncComputeError::ComputationFailed)??; + let task_index = output.1; + while let Some(o) = output.0.pop() { + let index = o.0; + let node_index = NodeIndex::new(index); + // Satisfy deps from the executed computation in the DFG + for edge in self.edges.edges_directed(node_index, Direction::Outgoing) { + let child_index = edge.target(); + let child_node = self.graph.node_weight_mut(child_index).unwrap(); + if !child_node.inputs.is_empty() { + child_node.inputs[*edge.weight() as usize] = + DFGTaskInput::Val(o.1.expanded.clone()); + } + } + self.graph[node_index].result = Some(o.1); + } + for edge in task_dependences.edges_directed(task_index, Direction::Outgoing) { + let dependent_task_index = edge.target(); + let dependent_task = execution_graph + .node_weight_mut(dependent_task_index) + .unwrap(); + dependent_task + .dependence_counter + .fetch_sub(1, std::sync::atomic::Ordering::SeqCst); + if self.is_ready_task(dependent_task) { + let mut args = Vec::with_capacity(dependent_task.df_nodes.len()); + for nidx in dependent_task.df_nodes.iter() { + let n = self.graph.node_weight_mut(*nidx).unwrap(); + let opcode = n.opcode; + args.push((opcode, std::mem::take(&mut n.inputs), *nidx)); + } + set.spawn_blocking(move || execute_partition(args, dependent_task_index)); + } + } + } + Ok(()) + } +} + +fn add_execution_depedences( + graph: &Dag, + execution_graph: &mut Dag, + node_map: HashMap, +) -> Result<(), SyncComputeError> { + // Once the DFG is partitioned, we need to add dependences as + // edges in the execution graph + for edge in graph.edge_references() { + let (xsrc, xdst) = ( + node_map + .get(&edge.source()) + .ok_or(SyncComputeError::BadInputs)?, + node_map + .get(&edge.target()) + .ok_or(SyncComputeError::BadInputs)?, + ); + if xsrc != xdst && execution_graph.find_edge(*xsrc, *xdst).is_none() { + let _ = execution_graph.add_edge(*xsrc, *xdst, ()); + } + } + for node in 0..execution_graph.node_count() { + let deps = execution_graph + .edges_directed(node_index(node), Incoming) + .count(); + execution_graph[node_index(node)] + .dependence_counter + .store(deps, std::sync::atomic::Ordering::SeqCst); + } + Ok(()) +} + +fn partition_preserving_parallelism( + graph: &Dag, + execution_graph: &mut Dag, +) -> Result<(), SyncComputeError> { + // First sort the DAG in a schedulable order + let ts = daggy::petgraph::algo::toposort(graph, None) + .map_err(|_| SyncComputeError::UnsatisfiedDependence)?; + let mut vis = graph.visit_map(); + let mut node_map = HashMap::new(); + // Traverse the DAG and build a graph of connected components + // without siblings (i.e. without parallelism) + for nidx in ts.iter() { + if !vis.is_visited(nidx) { + vis.visit(*nidx); + let mut df_nodes = vec![*nidx]; + let mut stack = vec![*nidx]; + while let Some(n) = stack.pop() { + if graph.edges_directed(n, Direction::Outgoing).count() == 1 { + for child in graph.neighbors(n) { + if !vis.is_visited(&child.index()) + && graph.edges_directed(child, Direction::Incoming).count() == 1 + { + df_nodes.push(child); + stack.push(child); + vis.visit(child.index()); + } + } + } + } + let ex_node = execution_graph.add_node(ExecNode { + df_nodes: vec![], + dependence_counter: AtomicUsize::new(usize::MAX), + }); + for n in df_nodes.iter() { + node_map.insert(*n, ex_node); + } + execution_graph[ex_node].df_nodes = df_nodes; + } + } + add_execution_depedences(graph, execution_graph, node_map) +} + +fn partition_components( + graph: &Dag, + execution_graph: &mut Dag, +) -> Result<(), SyncComputeError> { + // First sort the DAG in a schedulable order + let ts = daggy::petgraph::algo::toposort(graph, None) + .map_err(|_| SyncComputeError::UnsatisfiedDependence)?; + let tsmap: HashMap<&NodeIndex, usize> = ts.iter().enumerate().map(|(c, x)| (x, c)).collect(); + let mut vis = graph.visit_map(); + // Traverse the DAG and build a graph of the connected components + for nidx in ts.iter() { + if !vis.is_visited(nidx) { + vis.visit(*nidx); + let mut df_nodes = vec![*nidx]; + let mut stack = vec![*nidx]; + // DFS from the entry point undirected to gather all nodes + // in the component + while let Some(n) = stack.pop() { + for neighbor in graph.graph().neighbors_undirected(n) { + if !vis.is_visited(&neighbor) { + df_nodes.push(neighbor); + stack.push(neighbor); + vis.visit(neighbor); + } + } + } + // Apply topsort to component nodes + df_nodes.sort_by_key(|x| { + tsmap + .get(x) + .ok_or(SyncComputeError::UnsatisfiedDependence) + .unwrap() + }); + execution_graph + .add_node(ExecNode { + df_nodes, + dependence_counter: AtomicUsize::new(0), + }) + .index(); + } + } + // As this partition is made by coalescing all connected + // components within the DFG, there are no dependences (edges) to + // add to the execution graph. + Ok(()) +} + +pub fn execute_partition( + computations: Vec<(i32, Vec, NodeIndex)>, + task_id: NodeIndex, +) -> Result<(Vec<(usize, InMemoryCiphertext)>, NodeIndex), SyncComputeError> { + let mut res: HashMap = HashMap::with_capacity(computations.len()); + for (opcode, inputs, nidx) in computations { + let mut cts = Vec::with_capacity(inputs.len()); + for i in inputs.iter() { + match i { + DFGTaskInput::Dep(d) => { + if let Some(d) = d { + if let Some(ct) = res.get(d) { + cts.push(ct.expanded.clone()); + } + } else { + return Err(SyncComputeError::ComputationFailed); + } + } + DFGTaskInput::Val(v) => { + cts.push(v.clone()); + } + } + } + let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index())?; + res.insert(node_index, result); + } + Ok((Vec::from_iter(res), task_id)) } diff --git a/fhevm-engine/executor/src/dfg/types.rs b/fhevm-engine/executor/src/dfg/types.rs index 123a25e1..616cc88d 100644 --- a/fhevm-engine/executor/src/dfg/types.rs +++ b/fhevm-engine/executor/src/dfg/types.rs @@ -1,11 +1,10 @@ -use fhevm_engine_common::types::{Handle, SupportedFheCiphertexts}; - use crate::server::InMemoryCiphertext; +use fhevm_engine_common::types::SupportedFheCiphertexts; pub type DFGTaskResult = Option; #[derive(Clone)] pub enum DFGTaskInput { Val(SupportedFheCiphertexts), - Handle(Handle), + Dep(Option), } diff --git a/fhevm-engine/executor/tests/scheduling_mapping.rs b/fhevm-engine/executor/tests/scheduling_mapping.rs new file mode 100644 index 00000000..43a0308a --- /dev/null +++ b/fhevm-engine/executor/tests/scheduling_mapping.rs @@ -0,0 +1,180 @@ +use executor::server::common::FheOperation; +use executor::server::executor::sync_compute_response::Resp; +use executor::server::executor::CompressedCiphertext; +use executor::server::executor::{ + fhevm_executor_client::FhevmExecutorClient, SyncComputation, SyncComputeRequest, +}; +use executor::server::executor::{sync_input::Input, SyncInput}; +use fhevm_engine_common::types::{SupportedFheCiphertexts, HANDLE_LEN}; +use std::time::SystemTime; +use tfhe::prelude::CiphertextList; +use tfhe::zk::ZkComputeLoad; +use tfhe::ProvenCompactCiphertextList; +use utils::get_test; +mod utils; + +fn get_handle(h: u32) -> Vec { + let tmp = [h; HANDLE_LEN / 4]; + let res: [u8; HANDLE_LEN] = unsafe { std::mem::transmute(tmp) }; + res.to_vec() +} + +#[tokio::test] +async fn schedule_multi_erc20() { + let mut num_samples: usize = 2; + let samples = std::env::var("FHEVM_TEST_NUM_SAMPLES"); + if let Ok(samples) = samples { + num_samples = samples.parse::().unwrap(); + } + let test = get_test().await; + test.keys.set_server_key_for_current_thread(); + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) + .await + .unwrap() + .max_decoding_message_size(usize::MAX); + let mut builder = ProvenCompactCiphertextList::builder(&test.keys.compact_public_key); + let list = builder + .push(100_u64) // Balance source + .push(10_u64) // Transfer amount + .push(20_u64) // Balance destination + .push(0_u64) // 0 + .build_with_proof_packed(&test.keys.public_params, &[], ZkComputeLoad::Proof) + .unwrap(); + let expander = list.expand_without_verification().unwrap(); + let bals = SupportedFheCiphertexts::FheUint64(expander.get(0).unwrap().unwrap()); + let bals = test.compress(bals); + let trxa = SupportedFheCiphertexts::FheUint64(expander.get(1).unwrap().unwrap()); + let trxa = test.compress(trxa); + let bald = SupportedFheCiphertexts::FheUint64(expander.get(2).unwrap().unwrap()); + let bald = test.compress(bald); + let zero = SupportedFheCiphertexts::FheUint64(expander.get(3).unwrap().unwrap()); + let zero = test.compress(zero); + let handle_bals = test.ciphertext_handle(&bals, 5); + let sync_input_bals = SyncInput { + input: Some(Input::Handle(handle_bals.clone())), + }; + let handle_trxa = test.ciphertext_handle(&trxa, 5); + let sync_input_trxa = SyncInput { + input: Some(Input::Handle(handle_trxa.clone())), + }; + let handle_bald = test.ciphertext_handle(&bald, 5); + let sync_input_bald = SyncInput { + input: Some(Input::Handle(handle_bald.clone())), + }; + let handle_zero = test.ciphertext_handle(&zero, 5); + let sync_input_zero = SyncInput { + input: Some(Input::Handle(handle_zero.clone())), + }; + + let mut computed_handles = vec![]; + for i in 0..=(num_samples * 4 - 1) as u32 { + let input = Some(Input::Handle(get_handle(i))); + computed_handles.push(SyncInput { input }); + } + + let mut computations = vec![]; + for i in 0..=(num_samples - 1) as u32 { + computations.push(SyncComputation { + operation: FheOperation::FheLe.into(), + result_handles: vec![get_handle(i * 4)], + inputs: vec![sync_input_trxa.clone(), sync_input_bals.clone()], + }); // Compare trxa <= bals + computations.push(SyncComputation { + operation: FheOperation::FheIfThenElse.into(), + result_handles: vec![get_handle(i * 4 + 1)], + inputs: vec![ + computed_handles[(i * 4) as usize].clone(), + sync_input_trxa.clone(), + sync_input_zero.clone(), + ], + }); // if trxa <= bals then trxa else zero + computations.push(SyncComputation { + operation: FheOperation::FheSub.into(), + result_handles: vec![get_handle(i * 4 + 2)], + inputs: vec![ + sync_input_bals.clone(), + computed_handles[(i * 4 + 1) as usize].clone(), + ], + }); // bals - trxa/zero + computations.push(SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![get_handle(i * 4 + 3)], + inputs: vec![ + sync_input_bald.clone(), + computed_handles[(i * 4 + 1) as usize].clone(), + ], + }); // bald + trxa/zero + } + let req = SyncComputeRequest { + computations, + compact_ciphertext_lists: vec![], + compressed_ciphertexts: vec![ + CompressedCiphertext { + handle: handle_bals, + serialization: bals, + }, + CompressedCiphertext { + handle: handle_trxa, + serialization: trxa, + }, + CompressedCiphertext { + handle: handle_bald, + serialization: bald, + }, + CompressedCiphertext { + handle: handle_zero, + serialization: zero, + }, + ], + }; + let now = SystemTime::now(); + let response = client.sync_compute(req).await.unwrap(); + println!("Execution time: {}", now.elapsed().unwrap().as_millis()); + let sync_compute_response = response.get_ref(); + let resp = sync_compute_response.resp.clone().unwrap(); + match resp { + Resp::ResultCiphertexts(cts) => { + assert!( + cts.ciphertexts.len() == num_samples * 4, + "wrong number of output ciphertexts {} instead of {}", + cts.ciphertexts.len(), + num_samples * 4 + ); + for i in 0..=(num_samples * 4 - 1) as u32 { + match &cts.ciphertexts[i as usize].handle { + a if *a == get_handle(i) => { + let mut tt = 0; + if i % 4 != 0 { + tt = 5; + } + let ctd = SupportedFheCiphertexts::decompress( + tt, + &cts.ciphertexts[i as usize].serialization, + ) + .unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "true" if i % 4 == 0 => (), // trxa <= bals true + "10" if i % 4 == 1 => (), // select trxa + "90" if i % 4 == 2 => (), // bals - trxa + "30" if i % 4 == 3 => (), // bald + trxa + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, cts.ciphertexts[i as usize].handle[0] + ), + } + } + _ => assert!( + false, + "unexpected handle 0x{:x}", + cts.ciphertexts[i as usize].handle[0] + ), + } + } + } + Resp::Error(e) => assert!(false, "error response: {}", e), + } +} diff --git a/fhevm-engine/executor/tests/scheduling_patterns.rs b/fhevm-engine/executor/tests/scheduling_patterns.rs new file mode 100644 index 00000000..7e709701 --- /dev/null +++ b/fhevm-engine/executor/tests/scheduling_patterns.rs @@ -0,0 +1,840 @@ +use executor::server::common::FheOperation; +use executor::server::executor::sync_compute_response::Resp; +use executor::server::executor::CompressedCiphertext; +use executor::server::executor::{ + fhevm_executor_client::FhevmExecutorClient, SyncComputation, SyncComputeRequest, +}; +use executor::server::executor::{sync_input::Input, SyncInput}; +use executor::server::SyncComputeError; +use fhevm_engine_common::types::{SupportedFheCiphertexts, HANDLE_LEN}; +use tfhe::prelude::CiphertextList; +use tfhe::zk::ZkComputeLoad; +use tfhe::ProvenCompactCiphertextList; +use utils::get_test; + +mod utils; + +#[tokio::test] +async fn schedule_circular_dependence() { + let test = get_test().await; + test.keys.set_server_key_for_current_thread(); + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) + .await + .unwrap(); + let sync_input1 = SyncInput { + input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), + }; + let sync_input2 = SyncInput { + input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), + }; + let sync_input3 = SyncInput { + input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), + }; + let computation1 = SyncComputation { + operation: FheOperation::FheNeg.into(), + result_handles: vec![vec![0xbb; HANDLE_LEN]], + inputs: vec![sync_input1], + }; + let computation2 = SyncComputation { + operation: FheOperation::FheNeg.into(), + result_handles: vec![vec![0xcc; HANDLE_LEN]], + inputs: vec![sync_input2], + }; + let computation3 = SyncComputation { + operation: FheOperation::FheNeg.into(), + result_handles: vec![vec![0xaa; HANDLE_LEN]], + inputs: vec![sync_input3], + }; + let req = SyncComputeRequest { + computations: vec![computation1, computation2, computation3], + compact_ciphertext_lists: vec![], + compressed_ciphertexts: vec![], + }; + let response = client.sync_compute(req).await.unwrap(); + let sync_compute_response = response.get_ref(); + let resp = sync_compute_response.resp.clone().unwrap(); + match resp { + Resp::ResultCiphertexts(_cts) => assert!( + false, + "Received ciphertext outputs despite circular dependence." + ), + Resp::Error(e) => assert!( + e == SyncComputeError::UnsatisfiedDependence as i32, + "Error response should be UnsatisfiedDependence but is {}", + e + ), + } +} + +#[tokio::test] +async fn schedule_dependent_computations() { + let test = get_test().await; + test.keys.set_server_key_for_current_thread(); + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) + .await + .unwrap(); + let mut builder = ProvenCompactCiphertextList::builder(&test.keys.compact_public_key); + let list = builder + .push(3_u16) + .push(5_u16) + .push(7_u16) + .push(11_u16) + .push(13_u16) + .build_with_proof_packed(&test.keys.public_params, &[], ZkComputeLoad::Proof) + .unwrap(); + let expander = list.expand_without_verification().unwrap(); + let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap()); + let ct1 = test.compress(ct1); + let ct2 = SupportedFheCiphertexts::FheUint16(expander.get(1).unwrap().unwrap()); + let ct2 = test.compress(ct2); + let ct3 = SupportedFheCiphertexts::FheUint16(expander.get(2).unwrap().unwrap()); + let ct3 = test.compress(ct3); + let ct4 = SupportedFheCiphertexts::FheUint16(expander.get(3).unwrap().unwrap()); + let ct4 = test.compress(ct4); + let ct5 = SupportedFheCiphertexts::FheUint16(expander.get(4).unwrap().unwrap()); + let ct5 = test.compress(ct5); + let handle1 = test.ciphertext_handle(&ct1, 3); + let sync_input1 = SyncInput { + input: Some(Input::Handle(handle1.clone())), + }; + let handle2 = test.ciphertext_handle(&ct2, 3); + let sync_input2 = SyncInput { + input: Some(Input::Handle(handle2.clone())), + }; + let handle3 = test.ciphertext_handle(&ct3, 3); + let sync_input3 = SyncInput { + input: Some(Input::Handle(handle3.clone())), + }; + let handle4 = test.ciphertext_handle(&ct4, 3); + let sync_input4 = SyncInput { + input: Some(Input::Handle(handle4.clone())), + }; + let handle5 = test.ciphertext_handle(&ct5, 3); + let sync_input5 = SyncInput { + input: Some(Input::Handle(handle5.clone())), + }; + let sync_input6 = SyncInput { + input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), + }; + let sync_input7 = SyncInput { + input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), + }; + let sync_input8 = SyncInput { + input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), + }; + let computation1 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xaa; HANDLE_LEN]], + inputs: vec![sync_input1, sync_input2], + }; + let computation2 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xbb; HANDLE_LEN]], + inputs: vec![sync_input3, sync_input4], + }; + let computation3 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xcc; HANDLE_LEN]], + inputs: vec![sync_input6, sync_input7], + }; + let computation4 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xdd; HANDLE_LEN]], + inputs: vec![sync_input5, sync_input8], + }; + let req = SyncComputeRequest { + computations: vec![computation4, computation3, computation2, computation1], + compact_ciphertext_lists: vec![], + compressed_ciphertexts: vec![ + CompressedCiphertext { + handle: handle1, + serialization: ct1, + }, + CompressedCiphertext { + handle: handle2, + serialization: ct2, + }, + CompressedCiphertext { + handle: handle3, + serialization: ct3, + }, + CompressedCiphertext { + handle: handle4, + serialization: ct4, + }, + CompressedCiphertext { + handle: handle5, + serialization: ct5, + }, + ], + }; + let response = client.sync_compute(req).await.unwrap(); + let sync_compute_response = response.get_ref(); + let resp = sync_compute_response.resp.clone().unwrap(); + match resp { + Resp::ResultCiphertexts(cts) => { + assert!( + cts.ciphertexts.len() == 4, + "wrong number of output ciphertexts {} instead of {}", + cts.ciphertexts.len(), + 4 + ); + let aa: Vec = vec![0xaa; HANDLE_LEN]; + let bb: Vec = vec![0xbb; HANDLE_LEN]; + let cc: Vec = vec![0xcc; HANDLE_LEN]; + let dd: Vec = vec![0xdd; HANDLE_LEN]; + for ct in cts.ciphertexts.iter() { + match &ct.handle { + a if *a == aa => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "8" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + b if *b == bb => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "18" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + c if *c == cc => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "26" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + d if *d == dd => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "39" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), + } + } + } + Resp::Error(e) => assert!(false, "error response: {}", e), + } +} + +#[tokio::test] +async fn schedule_y_patterns() { + let test = get_test().await; + test.keys.set_server_key_for_current_thread(); + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) + .await + .unwrap(); + let mut builder = ProvenCompactCiphertextList::builder(&test.keys.compact_public_key); + let list = builder + .push(1_u16) + .push(2_u16) + .push(3_u16) + .push(4_u16) + .push(5_u16) + .build_with_proof_packed(&test.keys.public_params, &[], ZkComputeLoad::Proof) + .unwrap(); + let expander = list.expand_without_verification().unwrap(); + let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap()); + let ct1 = test.compress(ct1); + let ct2 = SupportedFheCiphertexts::FheUint16(expander.get(1).unwrap().unwrap()); + let ct2 = test.compress(ct2); + let ct3 = SupportedFheCiphertexts::FheUint16(expander.get(2).unwrap().unwrap()); + let ct3 = test.compress(ct3); + let ct4 = SupportedFheCiphertexts::FheUint16(expander.get(3).unwrap().unwrap()); + let ct4 = test.compress(ct4); + let ct5 = SupportedFheCiphertexts::FheUint16(expander.get(4).unwrap().unwrap()); + let ct5 = test.compress(ct5); + let handle1 = test.ciphertext_handle(&ct1, 3); + let sync_input1 = SyncInput { + input: Some(Input::Handle(handle1.clone())), + }; + let handle2 = test.ciphertext_handle(&ct2, 3); + let sync_input2 = SyncInput { + input: Some(Input::Handle(handle2.clone())), + }; + let handle3 = test.ciphertext_handle(&ct3, 3); + let sync_input3 = SyncInput { + input: Some(Input::Handle(handle3.clone())), + }; + let handle4 = test.ciphertext_handle(&ct4, 3); + let sync_input4 = SyncInput { + input: Some(Input::Handle(handle4.clone())), + }; + let handle5 = test.ciphertext_handle(&ct5, 3); + let sync_input5 = SyncInput { + input: Some(Input::Handle(handle5.clone())), + }; + let sync_input_aa = SyncInput { + input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), + }; + let sync_input_bb = SyncInput { + input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), + }; + let sync_input_cc = SyncInput { + input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), + }; + let sync_input_dd = SyncInput { + input: Some(Input::Handle(vec![0xdd; HANDLE_LEN])), + }; + let sync_input_ee = SyncInput { + input: Some(Input::Handle(vec![0xee; HANDLE_LEN])), + }; + let sync_input_ff = SyncInput { + input: Some(Input::Handle(vec![0xff; HANDLE_LEN])), + }; + let sync_input_99 = SyncInput { + input: Some(Input::Handle(vec![0x99; HANDLE_LEN])), + }; + // Pattern Y + let computation1 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xaa; HANDLE_LEN]], + inputs: vec![sync_input1.clone(), sync_input1.clone()], + }; // Compute 1 + 1 + let computation2 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xbb; HANDLE_LEN]], + inputs: vec![sync_input2.clone(), sync_input2.clone()], + }; // 2 + 2 + let computation3 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xcc; HANDLE_LEN]], + inputs: vec![sync_input_aa.clone(), sync_input_bb.clone()], + }; // 2 + 4 + let computation4 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xdd; HANDLE_LEN]], + inputs: vec![sync_input_cc.clone(), sync_input3.clone()], + }; // 6 + 3 + // Pattern reverse Y + let computation5 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xee; HANDLE_LEN]], + inputs: vec![sync_input1.clone(), sync_input1.clone()], + }; // 1 + 1 + let computation6 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xff; HANDLE_LEN]], + inputs: vec![sync_input_ee.clone(), sync_input2.clone()], + }; // 2 + 2 + let computation7 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0x99; HANDLE_LEN]], + inputs: vec![sync_input5.clone(), sync_input_ff.clone()], + }; // 5 + 4 + let computation8 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0x88; HANDLE_LEN]], + inputs: vec![sync_input3.clone(), sync_input_ff.clone()], + }; // 3 + 4 + let req = SyncComputeRequest { + computations: vec![ + computation4, + computation3, + computation2, + computation1, + computation5, + computation6, + computation7, + computation8, + ], + compact_ciphertext_lists: vec![], + compressed_ciphertexts: vec![ + CompressedCiphertext { + handle: handle1, + serialization: ct1, + }, + CompressedCiphertext { + handle: handle2, + serialization: ct2, + }, + CompressedCiphertext { + handle: handle3, + serialization: ct3, + }, + CompressedCiphertext { + handle: handle4, + serialization: ct4, + }, + CompressedCiphertext { + handle: handle5, + serialization: ct5, + }, + ], + }; + let response = client.sync_compute(req).await.unwrap(); + let sync_compute_response = response.get_ref(); + let resp = sync_compute_response.resp.clone().unwrap(); + match resp { + Resp::ResultCiphertexts(cts) => { + assert!( + cts.ciphertexts.len() == 8, + "wrong number of output ciphertexts {} instead of {}", + cts.ciphertexts.len(), + 8 + ); + let aa: Vec = vec![0xaa; HANDLE_LEN]; + let bb: Vec = vec![0xbb; HANDLE_LEN]; + let cc: Vec = vec![0xcc; HANDLE_LEN]; + let dd: Vec = vec![0xdd; HANDLE_LEN]; + let ee: Vec = vec![0xee; HANDLE_LEN]; + let ff: Vec = vec![0xff; HANDLE_LEN]; + let x88: Vec = vec![0x88; HANDLE_LEN]; + let x99: Vec = vec![0x99; HANDLE_LEN]; + for ct in cts.ciphertexts.iter() { + match &ct.handle { + a if *a == aa => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "2" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + b if *b == bb => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "4" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + c if *c == cc => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "6" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + d if *d == dd => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "9" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + e if *e == ee => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "2" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + f if *f == ff => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "4" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + x if *x == x99 => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "9" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + x if *x == x88 => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "7" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), + } + } + } + Resp::Error(e) => assert!(false, "error response: {}", e), + } +} + +#[tokio::test] +async fn schedule_diamond_reduction_dependence_pattern() { + let test = get_test().await; + test.keys.set_server_key_for_current_thread(); + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) + .await + .unwrap(); + let mut builder = ProvenCompactCiphertextList::builder(&test.keys.compact_public_key); + let list = builder + .push(1_u16) + .push(2_u16) + .push(3_u16) + .push(4_u16) + .push(5_u16) + .build_with_proof_packed(&test.keys.public_params, &[], ZkComputeLoad::Proof) + .unwrap(); + let expander = list.expand_without_verification().unwrap(); + let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap()); + let ct1 = test.compress(ct1); + let ct2 = SupportedFheCiphertexts::FheUint16(expander.get(1).unwrap().unwrap()); + let ct2 = test.compress(ct2); + let ct3 = SupportedFheCiphertexts::FheUint16(expander.get(2).unwrap().unwrap()); + let ct3 = test.compress(ct3); + let ct4 = SupportedFheCiphertexts::FheUint16(expander.get(3).unwrap().unwrap()); + let ct4 = test.compress(ct4); + let ct5 = SupportedFheCiphertexts::FheUint16(expander.get(4).unwrap().unwrap()); + let ct5 = test.compress(ct5); + let handle1 = test.ciphertext_handle(&ct1, 3); + let sync_input1 = SyncInput { + input: Some(Input::Handle(handle1.clone())), + }; + let handle2 = test.ciphertext_handle(&ct2, 3); + let sync_input2 = SyncInput { + input: Some(Input::Handle(handle2.clone())), + }; + let handle3 = test.ciphertext_handle(&ct3, 3); + let sync_input3 = SyncInput { + input: Some(Input::Handle(handle3.clone())), + }; + let handle4 = test.ciphertext_handle(&ct4, 3); + let sync_input4 = SyncInput { + input: Some(Input::Handle(handle4.clone())), + }; + let handle5 = test.ciphertext_handle(&ct5, 3); + let sync_input5 = SyncInput { + input: Some(Input::Handle(handle5.clone())), + }; + let sync_input_aa = SyncInput { + input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), + }; + let sync_input_bb = SyncInput { + input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), + }; + let sync_input_cc = SyncInput { + input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), + }; + let sync_input_dd = SyncInput { + input: Some(Input::Handle(vec![0xdd; HANDLE_LEN])), + }; + let sync_input_ee = SyncInput { + input: Some(Input::Handle(vec![0xee; HANDLE_LEN])), + }; + let sync_input_ff = SyncInput { + input: Some(Input::Handle(vec![0xff; HANDLE_LEN])), + }; + let sync_input_99 = SyncInput { + input: Some(Input::Handle(vec![0x99; HANDLE_LEN])), + }; + let computation1 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xaa; HANDLE_LEN]], + inputs: vec![sync_input1.clone(), sync_input1], + }; // Compute 1 + 1 + let computation2 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xbb; HANDLE_LEN]], + inputs: vec![sync_input2, sync_input_aa.clone()], + }; // 2 + 2 + let computation3 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xcc; HANDLE_LEN]], + inputs: vec![sync_input3, sync_input_aa.clone()], + }; // 2 + 3 + let computation4 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xdd; HANDLE_LEN]], + inputs: vec![sync_input4, sync_input_aa.clone()], + }; // 2 + 4 + let computation5 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xee; HANDLE_LEN]], + inputs: vec![sync_input5, sync_input_aa.clone()], + }; // 2 + 5 + let computation6 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xff; HANDLE_LEN]], + inputs: vec![sync_input_bb, sync_input_cc], + }; // 4 + 5 + let computation7 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0x99; HANDLE_LEN]], + inputs: vec![sync_input_dd, sync_input_ee], + }; // 6 + 7 + let computation8 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0x88; HANDLE_LEN]], + inputs: vec![sync_input_ff, sync_input_99], + }; // 9 + 13 + let req = SyncComputeRequest { + computations: vec![ + computation4, + computation3, + computation2, + computation1, + computation5, + computation6, + computation7, + computation8, + ], + compact_ciphertext_lists: vec![], + compressed_ciphertexts: vec![ + CompressedCiphertext { + handle: handle1, + serialization: ct1, + }, + CompressedCiphertext { + handle: handle2, + serialization: ct2, + }, + CompressedCiphertext { + handle: handle3, + serialization: ct3, + }, + CompressedCiphertext { + handle: handle4, + serialization: ct4, + }, + CompressedCiphertext { + handle: handle5, + serialization: ct5, + }, + ], + }; + let response = client.sync_compute(req).await.unwrap(); + let sync_compute_response = response.get_ref(); + let resp = sync_compute_response.resp.clone().unwrap(); + match resp { + Resp::ResultCiphertexts(cts) => { + assert!( + cts.ciphertexts.len() == 8, + "wrong number of output ciphertexts {} instead of {}", + cts.ciphertexts.len(), + 8 + ); + let aa: Vec = vec![0xaa; HANDLE_LEN]; + let bb: Vec = vec![0xbb; HANDLE_LEN]; + let cc: Vec = vec![0xcc; HANDLE_LEN]; + let dd: Vec = vec![0xdd; HANDLE_LEN]; + let ee: Vec = vec![0xee; HANDLE_LEN]; + let ff: Vec = vec![0xff; HANDLE_LEN]; + let x88: Vec = vec![0x88; HANDLE_LEN]; + let x99: Vec = vec![0x99; HANDLE_LEN]; + for ct in cts.ciphertexts.iter() { + match &ct.handle { + a if *a == aa => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "2" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + b if *b == bb => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "4" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + c if *c == cc => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "5" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + d if *d == dd => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "6" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + e if *e == ee => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "7" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + f if *f == ff => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "9" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + x if *x == x99 => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "13" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + x if *x == x88 => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "22" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), + } + } + } + Resp::Error(e) => assert!(false, "error response: {}", e), + } +} diff --git a/fhevm-engine/executor/tests/sync_compute.rs b/fhevm-engine/executor/tests/sync_compute.rs index ff4c2b15..00fdee56 100644 --- a/fhevm-engine/executor/tests/sync_compute.rs +++ b/fhevm-engine/executor/tests/sync_compute.rs @@ -273,541 +273,3 @@ async fn compute_on_result_ciphertext() { Resp::Error(e) => assert!(false, "error response: {}", e), } } - -#[tokio::test] -async fn schedule_dependent_computations() { - let test = get_test().await; - test.keys.set_server_key_for_current_thread(); - let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) - .await - .unwrap(); - let mut builder = ProvenCompactCiphertextList::builder(&test.keys.compact_public_key); - let list = builder - .push(3_u16) - .push(5_u16) - .push(7_u16) - .push(11_u16) - .push(13_u16) - .build_with_proof_packed(&test.keys.public_params, &[], ZkComputeLoad::Proof) - .unwrap(); - let expander = list.expand_without_verification().unwrap(); - let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap()); - let ct1 = test.compress(ct1); - let ct2 = SupportedFheCiphertexts::FheUint16(expander.get(1).unwrap().unwrap()); - let ct2 = test.compress(ct2); - let ct3 = SupportedFheCiphertexts::FheUint16(expander.get(2).unwrap().unwrap()); - let ct3 = test.compress(ct3); - let ct4 = SupportedFheCiphertexts::FheUint16(expander.get(3).unwrap().unwrap()); - let ct4 = test.compress(ct4); - let ct5 = SupportedFheCiphertexts::FheUint16(expander.get(4).unwrap().unwrap()); - let ct5 = test.compress(ct5); - let handle1 = test.ciphertext_handle(&ct1, 3); - let sync_input1 = SyncInput { - input: Some(Input::Handle(handle1.clone())), - }; - let handle2 = test.ciphertext_handle(&ct2, 3); - let sync_input2 = SyncInput { - input: Some(Input::Handle(handle2.clone())), - }; - let handle3 = test.ciphertext_handle(&ct3, 3); - let sync_input3 = SyncInput { - input: Some(Input::Handle(handle3.clone())), - }; - let handle4 = test.ciphertext_handle(&ct4, 3); - let sync_input4 = SyncInput { - input: Some(Input::Handle(handle4.clone())), - }; - let handle5 = test.ciphertext_handle(&ct5, 3); - let sync_input5 = SyncInput { - input: Some(Input::Handle(handle5.clone())), - }; - let sync_input6 = SyncInput { - input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), - }; - let sync_input7 = SyncInput { - input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), - }; - let sync_input8 = SyncInput { - input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), - }; - - let computation1 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xaa; HANDLE_LEN]], - inputs: vec![sync_input1, sync_input2], - }; - let computation2 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xbb; HANDLE_LEN]], - inputs: vec![sync_input3, sync_input4], - }; - let computation3 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xcc; HANDLE_LEN]], - inputs: vec![sync_input6, sync_input7], - }; - let computation4 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xdd; HANDLE_LEN]], - inputs: vec![sync_input5, sync_input8], - }; - - let req = SyncComputeRequest { - computations: vec![computation4, computation3, computation2, computation1], - compact_ciphertext_lists: vec![], - compressed_ciphertexts: vec![ - CompressedCiphertext { - handle: handle1, - serialization: ct1, - }, - CompressedCiphertext { - handle: handle2, - serialization: ct2, - }, - CompressedCiphertext { - handle: handle3, - serialization: ct3, - }, - CompressedCiphertext { - handle: handle4, - serialization: ct4, - }, - CompressedCiphertext { - handle: handle5, - serialization: ct5, - }, - ], - }; - let response = client.sync_compute(req).await.unwrap(); - let sync_compute_response = response.get_ref(); - let resp = sync_compute_response.resp.clone().unwrap(); - match resp { - Resp::ResultCiphertexts(cts) => { - assert!( - cts.ciphertexts.len() == 4, - "wrong number of output ciphertexts {} instead of {}", - cts.ciphertexts.len(), - 4 - ); - let aa: Vec = vec![0xaa; HANDLE_LEN]; - let bb: Vec = vec![0xbb; HANDLE_LEN]; - let cc: Vec = vec![0xcc; HANDLE_LEN]; - let dd: Vec = vec![0xdd; HANDLE_LEN]; - for ct in cts.ciphertexts.iter() { - match &ct.handle { - a if *a == aa => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "8" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - b if *b == bb => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "18" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - c if *c == cc => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "26" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - d if *d == dd => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "39" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), - } - } - } - Resp::Error(e) => assert!(false, "error response: {}", e), - } -} - -#[tokio::test] -async fn schedule_circular_dependence() { - let test = get_test().await; - test.keys.set_server_key_for_current_thread(); - let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) - .await - .unwrap(); - let sync_input1 = SyncInput { - input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), - }; - let sync_input2 = SyncInput { - input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), - }; - let sync_input3 = SyncInput { - input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), - }; - - let computation1 = SyncComputation { - operation: FheOperation::FheNeg.into(), - result_handles: vec![vec![0xbb; HANDLE_LEN]], - inputs: vec![sync_input1], - }; - let computation2 = SyncComputation { - operation: FheOperation::FheNeg.into(), - result_handles: vec![vec![0xcc; HANDLE_LEN]], - inputs: vec![sync_input2], - }; - let computation3 = SyncComputation { - operation: FheOperation::FheNeg.into(), - result_handles: vec![vec![0xaa; HANDLE_LEN]], - inputs: vec![sync_input3], - }; - - let req = SyncComputeRequest { - computations: vec![computation1, computation2, computation3], - compact_ciphertext_lists: vec![], - compressed_ciphertexts: vec![], - }; - let response = client.sync_compute(req).await.unwrap(); - let sync_compute_response = response.get_ref(); - let resp = sync_compute_response.resp.clone().unwrap(); - match resp { - Resp::ResultCiphertexts(_cts) => assert!( - false, - "Received ciphertext outputs despite circular dependence." - ), - Resp::Error(e) => assert!( - e == SyncComputeError::UnsatisfiedDependence as i32, - "Error response should be UnsatisfiedDependence but is {}", - e - ), - } -} - -#[tokio::test] -async fn schedule_diamond_reduction_dependence_pattern() { - let test = get_test().await; - test.keys.set_server_key_for_current_thread(); - let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) - .await - .unwrap(); - let mut builder = ProvenCompactCiphertextList::builder(&test.keys.compact_public_key); - let list = builder - .push(1_u16) - .push(2_u16) - .push(3_u16) - .push(4_u16) - .push(5_u16) - .build_with_proof_packed(&test.keys.public_params, &[], ZkComputeLoad::Proof) - .unwrap(); - let expander = list.expand_without_verification().unwrap(); - let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap()); - let ct1 = test.compress(ct1); - let ct2 = SupportedFheCiphertexts::FheUint16(expander.get(1).unwrap().unwrap()); - let ct2 = test.compress(ct2); - let ct3 = SupportedFheCiphertexts::FheUint16(expander.get(2).unwrap().unwrap()); - let ct3 = test.compress(ct3); - let ct4 = SupportedFheCiphertexts::FheUint16(expander.get(3).unwrap().unwrap()); - let ct4 = test.compress(ct4); - let ct5 = SupportedFheCiphertexts::FheUint16(expander.get(4).unwrap().unwrap()); - let ct5 = test.compress(ct5); - let handle1 = test.ciphertext_handle(&ct1, 3); - let sync_input1 = SyncInput { - input: Some(Input::Handle(handle1.clone())), - }; - let handle2 = test.ciphertext_handle(&ct2, 3); - let sync_input2 = SyncInput { - input: Some(Input::Handle(handle2.clone())), - }; - let handle3 = test.ciphertext_handle(&ct3, 3); - let sync_input3 = SyncInput { - input: Some(Input::Handle(handle3.clone())), - }; - let handle4 = test.ciphertext_handle(&ct4, 3); - let sync_input4 = SyncInput { - input: Some(Input::Handle(handle4.clone())), - }; - let handle5 = test.ciphertext_handle(&ct5, 3); - let sync_input5 = SyncInput { - input: Some(Input::Handle(handle5.clone())), - }; - let sync_input_aa = SyncInput { - input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), - }; - let sync_input_bb = SyncInput { - input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), - }; - let sync_input_cc = SyncInput { - input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), - }; - let sync_input_dd = SyncInput { - input: Some(Input::Handle(vec![0xdd; HANDLE_LEN])), - }; - let sync_input_ee = SyncInput { - input: Some(Input::Handle(vec![0xee; HANDLE_LEN])), - }; - let sync_input_ff = SyncInput { - input: Some(Input::Handle(vec![0xff; HANDLE_LEN])), - }; - let sync_input_99 = SyncInput { - input: Some(Input::Handle(vec![0x99; HANDLE_LEN])), - }; - - let computation1 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xaa; HANDLE_LEN]], - inputs: vec![sync_input1.clone(), sync_input1], - }; // Compute 1 + 1 - let computation2 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xbb; HANDLE_LEN]], - inputs: vec![sync_input2, sync_input_aa.clone()], - }; // 2 + 2 - let computation3 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xcc; HANDLE_LEN]], - inputs: vec![sync_input3, sync_input_aa.clone()], - }; // 2 + 3 - let computation4 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xdd; HANDLE_LEN]], - inputs: vec![sync_input4, sync_input_aa.clone()], - }; // 2 + 4 - let computation5 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xee; HANDLE_LEN]], - inputs: vec![sync_input5, sync_input_aa.clone()], - }; // 2 + 5 - - let computation6 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0xff; HANDLE_LEN]], - inputs: vec![sync_input_bb, sync_input_cc], - }; // 4 + 5 - let computation7 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0x99; HANDLE_LEN]], - inputs: vec![sync_input_dd, sync_input_ee], - }; // 6 + 7 - let computation8 = SyncComputation { - operation: FheOperation::FheAdd.into(), - result_handles: vec![vec![0x88; HANDLE_LEN]], - inputs: vec![sync_input_ff, sync_input_99], - }; // 9 + 13 - - let req = SyncComputeRequest { - computations: vec![ - computation4, - computation3, - computation2, - computation1, - computation5, - computation6, - computation7, - computation8, - ], - compact_ciphertext_lists: vec![], - compressed_ciphertexts: vec![ - CompressedCiphertext { - handle: handle1, - serialization: ct1, - }, - CompressedCiphertext { - handle: handle2, - serialization: ct2, - }, - CompressedCiphertext { - handle: handle3, - serialization: ct3, - }, - CompressedCiphertext { - handle: handle4, - serialization: ct4, - }, - CompressedCiphertext { - handle: handle5, - serialization: ct5, - }, - ], - }; - let response = client.sync_compute(req).await.unwrap(); - let sync_compute_response = response.get_ref(); - let resp = sync_compute_response.resp.clone().unwrap(); - match resp { - Resp::ResultCiphertexts(cts) => { - assert!( - cts.ciphertexts.len() == 8, - "wrong number of output ciphertexts {} instead of {}", - cts.ciphertexts.len(), - 8 - ); - let aa: Vec = vec![0xaa; HANDLE_LEN]; - let bb: Vec = vec![0xbb; HANDLE_LEN]; - let cc: Vec = vec![0xcc; HANDLE_LEN]; - let dd: Vec = vec![0xdd; HANDLE_LEN]; - let ee: Vec = vec![0xee; HANDLE_LEN]; - let ff: Vec = vec![0xff; HANDLE_LEN]; - let x88: Vec = vec![0x88; HANDLE_LEN]; - let x99: Vec = vec![0x99; HANDLE_LEN]; - for ct in cts.ciphertexts.iter() { - match &ct.handle { - a if *a == aa => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "2" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - b if *b == bb => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "4" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - c if *c == cc => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "5" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - d if *d == dd => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "6" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - e if *e == ee => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "7" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - f if *f == ff => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "9" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - x if *x == x99 => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "13" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - x if *x == x88 => { - let ctd = - SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); - match ctd - .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) - .as_str() - { - "22" => (), - s => assert!( - false, - "unexpected result: {} for handle 0x{:x}", - s, ct.handle[0] - ), - } - } - _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), - } - } - } - Resp::Error(e) => assert!(false, "error response: {}", e), - } -} From 18dbf63640ed42e34bdeea2a38fdbcbeff6b520d Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Wed, 9 Oct 2024 15:15:46 +0100 Subject: [PATCH 2/4] feat(scheduler): add rayon thread pool management This will restrict the rayon thread pool exposed to TFHE-rs in order to improve the scheduling of concurrent FHE operations. Each re-usable thread pool is used exclusively for one scheduler task, which improves cache locality and reduces interference from other operations. --- fhevm-engine/Cargo.lock | 1 + fhevm-engine/Cargo.toml | 1 + fhevm-engine/executor/Cargo.toml | 1 + fhevm-engine/executor/src/dfg/scheduler.rs | 71 +++++++++++++++++++--- fhevm-engine/executor/src/server.rs | 26 +++++++- 5 files changed, 90 insertions(+), 10 deletions(-) diff --git a/fhevm-engine/Cargo.lock b/fhevm-engine/Cargo.lock index 272007f6..337cf6c3 100644 --- a/fhevm-engine/Cargo.lock +++ b/fhevm-engine/Cargo.lock @@ -2198,6 +2198,7 @@ dependencies = [ "daggy", "fhevm-engine-common", "prost", + "rayon", "sha3", "tfhe", "tokio", diff --git a/fhevm-engine/Cargo.toml b/fhevm-engine/Cargo.toml index e474d5da..829df784 100644 --- a/fhevm-engine/Cargo.toml +++ b/fhevm-engine/Cargo.toml @@ -15,6 +15,7 @@ serde = "1.0.210" prometheus = "0.13.4" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["fmt", "json"] } +rayon = "1.10.0" [profile.dev.package.tfhe] overflow-checks = false diff --git a/fhevm-engine/executor/Cargo.toml b/fhevm-engine/executor/Cargo.toml index 64e8e333..4e67f2f0 100644 --- a/fhevm-engine/executor/Cargo.toml +++ b/fhevm-engine/executor/Cargo.toml @@ -17,6 +17,7 @@ bincode.workspace = true sha3.workspace = true anyhow.workspace = true daggy.workspace = true +rayon.workspace = true fhevm-engine-common = { path = "../fhevm-engine-common" } [build-dependencies] diff --git a/fhevm-engine/executor/src/dfg/scheduler.rs b/fhevm-engine/executor/src/dfg/scheduler.rs index f0d49fea..dbf4b21e 100644 --- a/fhevm-engine/executor/src/dfg/scheduler.rs +++ b/fhevm-engine/executor/src/dfg/scheduler.rs @@ -1,9 +1,10 @@ +use std::borrow::Borrow; use std::collections::HashMap; use std::sync::atomic::AtomicUsize; use crate::dfg::types::*; use crate::dfg::{OpEdge, OpNode}; -use crate::server::{run_computation, InMemoryCiphertext, SyncComputeError}; +use crate::server::{self, run_computation, InMemoryCiphertext, SyncComputeError}; use anyhow::Result; use daggy::petgraph::csr::IndexType; use daggy::petgraph::graph::node_index; @@ -11,6 +12,9 @@ use daggy::petgraph::visit::{IntoEdgeReferences, IntoNeighbors, VisitMap, Visita use daggy::petgraph::Direction::Incoming; use fhevm_engine_common::types::SupportedFheCiphertexts; +use rayon::prelude::*; +use std::sync::mpsc::channel; + use daggy::{ petgraph::{ visit::{EdgeRef, IntoEdgesDirected}, @@ -80,11 +84,11 @@ impl<'a> Scheduler<'a> { self.schedule_coarse_grain(PartitionStrategy::MaxLocality) .await } - Ok(val) if val == "LOOP" => panic!("Unimplemented LOOP scheduling strategy"), + Ok(val) if val == "LOOP" => self.schedule_component_loop().await, Ok(val) if val == "FINE_GRAIN" => self.schedule_fine_grain().await, Ok(unhandled) => panic!("Scheduling strategy {:?} does not exist", unhandled), - _ => self.schedule_fine_grain().await, + _ => self.schedule_component_loop().await, } } @@ -168,7 +172,7 @@ impl<'a> Scheduler<'a> { let opcode = n.opcode; args.push((opcode, std::mem::take(&mut n.inputs), *nidx)); } - set.spawn_blocking(move || execute_partition(args, index)); + set.spawn_blocking(move || execute_partition(args, index, false)); } } // Get results from computations and update dependences of remaining computations @@ -204,8 +208,47 @@ impl<'a> Scheduler<'a> { let opcode = n.opcode; args.push((opcode, std::mem::take(&mut n.inputs), *nidx)); } - set.spawn_blocking(move || execute_partition(args, dependent_task_index)); + set.spawn_blocking(move || { + execute_partition(args, dependent_task_index, false) + }); + } + } + } + Ok(()) + } + + async fn schedule_component_loop(&mut self) -> Result<(), SyncComputeError> { + let mut execution_graph: Dag = Dag::default(); + let _ = partition_components(self.graph, &mut execution_graph); + let mut comps = vec![]; + + // Prime the scheduler with all nodes without dependences + for idx in 0..execution_graph.node_count() { + let index = NodeIndex::new(idx); + let node = execution_graph.node_weight_mut(index).unwrap(); + if self.is_ready_task(node) { + let mut args = Vec::with_capacity(node.df_nodes.len()); + for nidx in node.df_nodes.iter() { + let n = self.graph.node_weight_mut(*nidx).unwrap(); + let opcode = n.opcode; + args.push((opcode, std::mem::take(&mut n.inputs), *nidx)); } + comps.push((std::mem::take(&mut args), index)); + } + } + + let (src, dest) = channel(); + comps.par_iter().for_each_with(src, |src, (args, index)| { + src.send(execute_partition(args.to_vec(), *index, true)) + .unwrap(); + }); + let results: Vec<_> = dest.iter().collect(); + for result in results { + let mut output = result.map_err(|_| SyncComputeError::ComputationFailed)?; + while let Some(o) = output.0.pop() { + let index = o.0; + let node_index = NodeIndex::new(index); + self.graph[node_index].result = Some(o.1); } } Ok(()) @@ -335,6 +378,7 @@ fn partition_components( pub fn execute_partition( computations: Vec<(i32, Vec, NodeIndex)>, task_id: NodeIndex, + use_global_threadpool: bool, ) -> Result<(Vec<(usize, InMemoryCiphertext)>, NodeIndex), SyncComputeError> { let mut res: HashMap = HashMap::with_capacity(computations.len()); for (opcode, inputs, nidx) in computations { @@ -355,8 +399,21 @@ pub fn execute_partition( } } } - let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index())?; - res.insert(node_index, result); + if use_global_threadpool { + let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index())?; + res.insert(node_index, result); + } else { + let thread_pool = server::THREAD_POOL + .borrow() + .take() + .ok_or(SyncComputeError::ComputationFailed)?; + thread_pool.install(|| -> Result<(), SyncComputeError> { + let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index())?; + res.insert(node_index, result); + Ok(()) + })?; + server::THREAD_POOL.set(Some(thread_pool)); + } } Ok((Vec::from_iter(res), task_id)) } diff --git a/fhevm-engine/executor/src/server.rs b/fhevm-engine/executor/src/server.rs index f1e64ac8..0e450e92 100644 --- a/fhevm-engine/executor/src/server.rs +++ b/fhevm-engine/executor/src/server.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use anyhow::Result; pub use common::FheOperation; use executor::{ @@ -16,7 +14,8 @@ use fhevm_engine_common::{ types::{get_ct_type, FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN}, }; use sha3::{Digest, Keccak256}; -use tfhe::{integer::U256, set_server_key, zk::CompactPkePublicParams}; +use std::{cell::RefCell, collections::HashMap}; +use tfhe::{integer::U256, set_server_key, zk::CompactPkePublicParams, ServerKey}; use tokio::task::spawn_blocking; use tonic::{transport::Server, Code, Request, Response, Status}; @@ -30,14 +29,29 @@ pub mod executor { tonic::include_proto!("fhevm.executor"); } +thread_local! { + pub static THREAD_POOL: RefCell> = const {RefCell::new(None)}; +} + pub fn start(args: &crate::cli::Args) -> Result<()> { let keys: FhevmKeys = SerializedFhevmKeys::load_from_disk().into(); let executor = FhevmExecutorService::new(); + rayon::broadcast(|_| { + set_server_key(keys.server_key.clone()); + }); let runtime = tokio::runtime::Builder::new_multi_thread() .worker_threads(args.tokio_threads) .max_blocking_threads(args.fhe_compute_threads) .on_thread_start(move || { set_server_key(keys.server_key.clone()); + let rayon_pool = rayon::ThreadPoolBuilder::new() + .num_threads(8) + .build() + .unwrap(); + rayon_pool.broadcast(|_| { + set_server_key(keys.server_key.clone()); + }); + THREAD_POOL.set(Some(rayon_pool)); }) .enable_all() .build()?; @@ -108,9 +122,15 @@ impl FhevmExecutor for FhevmExecutorService { } // Schedule computations in parallel as dependences allow let mut sched = Scheduler::new(&mut graph.graph); + + let now = std::time::SystemTime::now(); if sched.schedule().await.is_err() { return Some(Resp::Error(SyncComputeError::ComputationFailed.into())); } + println!( + "Execution time (sched): {}", + now.elapsed().unwrap().as_millis() + ); // Extract the results from the graph match graph.get_results() { Ok(result_cts) => Some(Resp::ResultCiphertexts(ResultCiphertexts { From ac6740206caedc2b52527a1b0b2778289babdcda Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Wed, 16 Oct 2024 13:33:11 +0100 Subject: [PATCH 3/4] fix(executor): add FHE operation threads parameter and simplify execution strategies --- fhevm-engine/executor/src/cli.rs | 3 ++ fhevm-engine/executor/src/dfg/scheduler.rs | 33 ++++++++++--------- fhevm-engine/executor/src/server.rs | 3 +- .../executor/tests/scheduling_mapping.rs | 7 ++-- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/fhevm-engine/executor/src/cli.rs b/fhevm-engine/executor/src/cli.rs index bdf4548d..7d8f3d11 100644 --- a/fhevm-engine/executor/src/cli.rs +++ b/fhevm-engine/executor/src/cli.rs @@ -9,6 +9,9 @@ pub struct Args { #[arg(long, default_value_t = 8)] pub fhe_compute_threads: usize, + #[arg(long, default_value_t = 8)] + pub fhe_operation_threads: usize, + #[arg(long, default_value = "127.0.0.1:50051")] pub server_addr: String, } diff --git a/fhevm-engine/executor/src/dfg/scheduler.rs b/fhevm-engine/executor/src/dfg/scheduler.rs index dbf4b21e..286e8e79 100644 --- a/fhevm-engine/executor/src/dfg/scheduler.rs +++ b/fhevm-engine/executor/src/dfg/scheduler.rs @@ -76,18 +76,19 @@ impl<'a> Scheduler<'a> { pub async fn schedule(&mut self) -> Result<(), SyncComputeError> { let schedule_type = std::env::var("FHEVM_DF_SCHEDULE"); match schedule_type { - Ok(val) if val == "MAX_PARALLELISM" => { - self.schedule_coarse_grain(PartitionStrategy::MaxParallelism) - .await - } - Ok(val) if val == "MAX_LOCALITY" => { - self.schedule_coarse_grain(PartitionStrategy::MaxLocality) - .await - } - Ok(val) if val == "LOOP" => self.schedule_component_loop().await, - Ok(val) if val == "FINE_GRAIN" => self.schedule_fine_grain().await, - Ok(unhandled) => panic!("Scheduling strategy {:?} does not exist", unhandled), - + Ok(val) => match val.as_str() { + "MAX_PARALLELISM" => { + self.schedule_coarse_grain(PartitionStrategy::MaxParallelism) + .await + } + "MAX_LOCALITY" => { + self.schedule_coarse_grain(PartitionStrategy::MaxLocality) + .await + } + "LOOP" => self.schedule_component_loop().await, + "FINE_GRAIN" => self.schedule_fine_grain().await, + unhandled => panic!("Scheduling strategy {:?} does not exist", unhandled), + }, _ => self.schedule_component_loop().await, } } @@ -151,14 +152,14 @@ impl<'a> Scheduler<'a> { Result<(Vec<(usize, InMemoryCiphertext)>, NodeIndex), SyncComputeError>, > = JoinSet::new(); let mut execution_graph: Dag = Dag::default(); - match strategy { + let _ = match strategy { PartitionStrategy::MaxLocality => { - let _ = partition_components(self.graph, &mut execution_graph); + partition_components(self.graph, &mut execution_graph) } PartitionStrategy::MaxParallelism => { - let _ = partition_preserving_parallelism(self.graph, &mut execution_graph); + partition_preserving_parallelism(self.graph, &mut execution_graph) } - } + }; let task_dependences = execution_graph.map(|_, _| (), |_, edge| *edge); // Prime the scheduler with all nodes without dependences diff --git a/fhevm-engine/executor/src/server.rs b/fhevm-engine/executor/src/server.rs index 0e450e92..9014c274 100644 --- a/fhevm-engine/executor/src/server.rs +++ b/fhevm-engine/executor/src/server.rs @@ -36,6 +36,7 @@ thread_local! { pub fn start(args: &crate::cli::Args) -> Result<()> { let keys: FhevmKeys = SerializedFhevmKeys::load_from_disk().into(); let executor = FhevmExecutorService::new(); + let rayon_threads = args.fhe_operation_threads; rayon::broadcast(|_| { set_server_key(keys.server_key.clone()); }); @@ -45,7 +46,7 @@ pub fn start(args: &crate::cli::Args) -> Result<()> { .on_thread_start(move || { set_server_key(keys.server_key.clone()); let rayon_pool = rayon::ThreadPoolBuilder::new() - .num_threads(8) + .num_threads(rayon_threads) .build() .unwrap(); rayon_pool.broadcast(|_| { diff --git a/fhevm-engine/executor/tests/scheduling_mapping.rs b/fhevm-engine/executor/tests/scheduling_mapping.rs index 43a0308a..48f1169b 100644 --- a/fhevm-engine/executor/tests/scheduling_mapping.rs +++ b/fhevm-engine/executor/tests/scheduling_mapping.rs @@ -14,8 +14,11 @@ use utils::get_test; mod utils; fn get_handle(h: u32) -> Vec { - let tmp = [h; HANDLE_LEN / 4]; - let res: [u8; HANDLE_LEN] = unsafe { std::mem::transmute(tmp) }; + let mut res: Vec = Vec::with_capacity(HANDLE_LEN); + let slice: [u8; 4] = h.to_be_bytes(); + for _i in 0..HANDLE_LEN / 4 { + res.extend_from_slice(&slice); + } res.to_vec() } From 695f79af1910becc0e81b690140844733677be50 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Tue, 8 Oct 2024 16:22:52 +0100 Subject: [PATCH 4/4] fix(executor): fix optimization profile for release target --- fhevm-engine/Cargo.toml | 8 +++++--- fhevm-engine/fhevm-engine-common/Cargo.toml | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/fhevm-engine/Cargo.toml b/fhevm-engine/Cargo.toml index 829df784..aa5243c3 100644 --- a/fhevm-engine/Cargo.toml +++ b/fhevm-engine/Cargo.toml @@ -20,8 +20,10 @@ rayon = "1.10.0" [profile.dev.package.tfhe] overflow-checks = false -# for testing in release mode due to too big -# binary inside mac [profile.release] -opt-level = "z" +# for testing in release mode due to too big binary inside mac: +# set opt-level = "z" +# however, this leads to 2-4x slower execution due to loss of loop +# vectorization +opt-level = 3 lto = "fat" diff --git a/fhevm-engine/fhevm-engine-common/Cargo.toml b/fhevm-engine/fhevm-engine-common/Cargo.toml index 1b025978..f0a2c911 100644 --- a/fhevm-engine/fhevm-engine-common/Cargo.toml +++ b/fhevm-engine/fhevm-engine-common/Cargo.toml @@ -4,9 +4,9 @@ version = "0.1.0" edition = "2021" [target.'cfg(target_arch = "x86_64")'.dependencies] -tfhe = { version = "0.8.3", features = ["boolean", "shortint", "integer", "x86_64-unix", "zk-pok", "experimental-force_fft_algo_dif4"] } +tfhe = { version = "0.8.3", features = ["boolean", "shortint", "integer", "x86_64-unix", "zk-pok", "experimental-force_fft_algo_dif4", "nightly-avx512"] } [target.'cfg(target_arch = "aarch64")'.dependencies] -tfhe = { version = "0.8.3", features = ["boolean", "shortint", "integer", "aarch64-unix", "zk-pok", "experimental-force_fft_algo_dif4"] } +tfhe = { version = "0.8.3", features = ["boolean", "shortint", "integer", "aarch64-unix", "zk-pok", "experimental-force_fft_algo_dif4", "nightly-avx512"] } [dependencies] sha3.workspace = true