From 5080aeb2103e0c3fd77779f00a8ac77c8e0283ab Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Sat, 7 Sep 2024 10:20:04 +0100 Subject: [PATCH] feat(executor): add a parallel dataflow scheduler for computations in a SyncComputeRequest The scheduler runs on top of a DFG built using the daggy extension to petgraph and executes computations/tasks along the wavefront of maximum parallelism within the graph. Circular dependences are detected early. --- fhevm-engine/Cargo.lock | 10 + fhevm-engine/Cargo.toml | 1 + fhevm-engine/executor/Cargo.toml | 1 + fhevm-engine/executor/src/dfg.rs | 149 ++++++ fhevm-engine/executor/src/dfg/scheduler.rs | 83 +++ fhevm-engine/executor/src/dfg/types.rs | 11 + fhevm-engine/executor/src/lib.rs | 1 + fhevm-engine/executor/src/main.rs | 1 + fhevm-engine/executor/src/server.rs | 123 +++-- fhevm-engine/executor/tests/sync_compute.rs | 534 ++++++++++++++++++++ proto/executor.proto | 1 + 11 files changed, 874 insertions(+), 41 deletions(-) create mode 100644 fhevm-engine/executor/src/dfg.rs create mode 100644 fhevm-engine/executor/src/dfg/scheduler.rs create mode 100644 fhevm-engine/executor/src/dfg/types.rs diff --git a/fhevm-engine/Cargo.lock b/fhevm-engine/Cargo.lock index 0e735e3f..d1bf90b9 100644 --- a/fhevm-engine/Cargo.lock +++ b/fhevm-engine/Cargo.lock @@ -1650,6 +1650,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "daggy" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91a9304e55e9d601a39ae4deaba85406d5c0980e106f65afcf0460e9af1e7602" +dependencies = [ + "petgraph", +] + [[package]] name = "darling" version = "0.20.10" @@ -1910,6 +1919,7 @@ dependencies = [ "anyhow", "bincode", "clap", + "daggy", "fhevm-engine-common", "prost", "sha3", diff --git a/fhevm-engine/Cargo.toml b/fhevm-engine/Cargo.toml index 7be39b60..bbebd54f 100644 --- a/fhevm-engine/Cargo.toml +++ b/fhevm-engine/Cargo.toml @@ -10,6 +10,7 @@ tonic = { version = "0.12", features = ["server"] } bincode = "1.3.3" sha3 = "0.10.8" anyhow = "1.0.86" +daggy = "0.8.0" [profile.dev.package.tfhe] overflow-checks = false \ No newline at end of file diff --git a/fhevm-engine/executor/Cargo.toml b/fhevm-engine/executor/Cargo.toml index 9806735f..0752bd1e 100644 --- a/fhevm-engine/executor/Cargo.toml +++ b/fhevm-engine/executor/Cargo.toml @@ -16,6 +16,7 @@ tonic.workspace = true bincode.workspace = true sha3.workspace = true anyhow.workspace = true +daggy.workspace = true fhevm-engine-common = { path = "../fhevm-engine-common" } [build-dependencies] diff --git a/fhevm-engine/executor/src/dfg.rs b/fhevm-engine/executor/src/dfg.rs new file mode 100644 index 00000000..a10de529 --- /dev/null +++ b/fhevm-engine/executor/src/dfg.rs @@ -0,0 +1,149 @@ +pub mod scheduler; +mod types; + +use crate::dfg::types::*; +use crate::server::{ + CompressedCiphertext, ComputationState, Input, SyncComputation, SyncComputeError, + SyncComputeRequest, +}; +use anyhow::Result; +use fhevm_engine_common::types::{ + FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN, +}; +use tfhe::integer::U256; + +use daggy::{Dag, NodeIndex}; +use std::collections::HashMap; + +//TODO#[derive(Debug)] +pub struct Node<'a> { + computation: &'a SyncComputation, + result: DFGTaskResult, + result_handle: Handle, + inputs: Vec, +} +pub type Edge = u8; + +//TODO#[derive(Debug)] +#[derive(Default)] +pub struct DFGraph<'a> { + pub graph: Dag, Edge>, + produced_handles: HashMap<&'a Handle, NodeIndex>, +} + +impl<'a> DFGraph<'a> { + pub fn add_node( + &mut self, + computation: &'a SyncComputation, + inputs: Vec, + ) -> Result { + let rh = computation + .result_handles + .first() + .filter(|h| h.len() == HANDLE_LEN) + .ok_or(SyncComputeError::BadResultHandles)?; + Ok(self.graph.add_node(Node { + computation, + result: None, + result_handle: rh.clone(), + inputs, + })) + } + + pub fn add_dependence( + &mut self, + source: NodeIndex, + destination: NodeIndex, + consumer_input: Edge, + ) -> Result<(), SyncComputeError> { + let _edge = self + .graph + .add_edge(source, destination, consumer_input) + .map_err(|_| SyncComputeError::UnsatisfiedDependence)?; + Ok(()) + } + pub fn build_from_request( + &mut self, + req: &'a SyncComputeRequest, + state: &ComputationState, + ) -> Result<(), SyncComputeError> { + // Add all computations as nodes in the graph. + for computation in &req.computations { + let inputs: Result> = computation + .inputs + .iter() + .map(|input| match &input.input { + Some(input) => match input { + Input::Handle(h) => { + if let Some(ct) = state.ciphertexts.get(h) { + Ok(DFGTaskInput::Val(ct.expanded.clone())) + } else { + Ok(DFGTaskInput::Handle(h.clone())) + } + } + Input::Scalar(s) if s.len() == SCALAR_LEN => { + let mut scalar = U256::default(); + scalar.copy_from_be_byte_slice(s); + Ok(DFGTaskInput::Val(SupportedFheCiphertexts::Scalar(scalar))) + } + _ => Err(FhevmError::BadInputs.into()), + }, + None => Err(FhevmError::BadInputs.into()), + }) + .collect(); + if let Ok(mut inputs) = inputs { + let n = self.add_node(computation, std::mem::take(&mut inputs))?; + self.produced_handles.insert( + computation + .result_handles + .first() + .filter(|h| h.len() == HANDLE_LEN) + .ok_or(SyncComputeError::BadResultHandles)?, + n, + ); + } + } + // Traverse nodes and add dependences/edges as required + for index in 0..self.graph.node_count() { + let take_inputs = std::mem::take( + &mut self + .graph + .node_weight_mut(NodeIndex::new(index)) + .unwrap() + .inputs, + ); + for (idx, input) in take_inputs.iter().enumerate() { + match input { + DFGTaskInput::Handle(input) => { + if let Some(producer_index) = self.produced_handles.get(input) { + self.add_dependence(*producer_index, NodeIndex::new(index), idx as u8)?; + } + } + DFGTaskInput::Val(_) => {} + }; + } + self.graph + .node_weight_mut(NodeIndex::new(index)) + .unwrap() + .inputs = take_inputs; + } + + Ok(()) + } + + pub fn get_results(&mut self) -> Result, SyncComputeError> { + let mut res = Vec::with_capacity(self.graph.node_count()); + for index in 0..self.graph.node_count() { + let node = self.graph.node_weight_mut(NodeIndex::new(index)).unwrap(); + if let Some(imc) = &node.result { + res.push(CompressedCiphertext { + handle: node.result_handle.clone(), + serialization: imc.compressed.clone(), + }); + } else { + return Err(SyncComputeError::ComputationFailed); + } + } + Ok(res) + } +} diff --git a/fhevm-engine/executor/src/dfg/scheduler.rs b/fhevm-engine/executor/src/dfg/scheduler.rs new file mode 100644 index 00000000..d92e43bf --- /dev/null +++ b/fhevm-engine/executor/src/dfg/scheduler.rs @@ -0,0 +1,83 @@ +use crate::dfg::{types::DFGTaskInput, Edge, Node}; +use crate::server::{run_computation, InMemoryCiphertext, SyncComputeError}; +use anyhow::Result; +use fhevm_engine_common::types::SupportedFheCiphertexts; + +use daggy::{ + petgraph::{ + visit::{EdgeRef, IntoEdgesDirected}, + Direction, + }, + Dag, NodeIndex, +}; +use tokio::task::JoinSet; + +pub struct Scheduler<'a, 'b> { + graph: &'b mut Dag, Edge>, + edges: Dag<(), Edge>, + set: JoinSet>, +} + +impl<'a, 'b> Scheduler<'a, 'b> { + fn is_ready(node: &Node<'a>) -> bool { + let mut ready = true; + for i in node.inputs.iter() { + if let DFGTaskInput::Handle(_) = i { + ready = false; + } + } + ready + } + pub fn new(graph: &'b mut Dag, Edge>) -> Self { + let mut set = JoinSet::new(); + for idx in 0..graph.node_count() { + let index = NodeIndex::new(idx); + let node = graph.node_weight_mut(index).unwrap(); + if Self::is_ready(node) { + let opc = node.computation.operation; + let inputs: Result, SyncComputeError> = node + .inputs + .iter() + .map(|i| match i { + DFGTaskInput::Val(i) => Ok(i.clone()), + DFGTaskInput::Handle(_) => Err(SyncComputeError::ComputationFailed), + }) + .collect(); + set.spawn_blocking(move || run_computation(opc, inputs, idx)); + } + } + + let edges = graph.map(|_, _| (), |_, edge| *edge); + + Self { graph, edges, set } + } + pub async fn schedule(&mut self) -> Result<(), SyncComputeError> { + while let Some(result) = self.set.join_next().await { + let output = result.map_err(|_| SyncComputeError::ComputationFailed)??; + let index = output.0; + let node_index = NodeIndex::new(index); + // Satisfy deps from the executed task + for edge in self.edges.edges_directed(node_index, Direction::Outgoing) { + let child_index = edge.target(); + let child_node = self.graph.node_weight_mut(child_index).unwrap(); + child_node.inputs[*edge.weight() as usize] = + DFGTaskInput::Val(output.1.expanded.clone()); + if Self::is_ready(child_node) { + let opc = child_node.computation.operation; + let inputs: Result, SyncComputeError> = child_node + .inputs + .iter() + .map(|i| match i { + DFGTaskInput::Val(i) => Ok(i.clone()), + DFGTaskInput::Handle(_) => Err(SyncComputeError::ComputationFailed), + }) + .collect(); + self.set + .spawn_blocking(move || run_computation(opc, inputs, child_index.index())); + } + } + self.graph.node_weight_mut(node_index).unwrap().result = Some(output.1); + } + Ok(()) + } +} diff --git a/fhevm-engine/executor/src/dfg/types.rs b/fhevm-engine/executor/src/dfg/types.rs new file mode 100644 index 00000000..123a25e1 --- /dev/null +++ b/fhevm-engine/executor/src/dfg/types.rs @@ -0,0 +1,11 @@ +use fhevm_engine_common::types::{Handle, SupportedFheCiphertexts}; + +use crate::server::InMemoryCiphertext; + +pub type DFGTaskResult = Option; + +#[derive(Clone)] +pub enum DFGTaskInput { + Val(SupportedFheCiphertexts), + Handle(Handle), +} diff --git a/fhevm-engine/executor/src/lib.rs b/fhevm-engine/executor/src/lib.rs index 5a9b926e..36a1efd1 100644 --- a/fhevm-engine/executor/src/lib.rs +++ b/fhevm-engine/executor/src/lib.rs @@ -1,2 +1,3 @@ pub mod cli; +pub mod dfg; pub mod server; diff --git a/fhevm-engine/executor/src/main.rs b/fhevm-engine/executor/src/main.rs index be804a69..5954f42d 100644 --- a/fhevm-engine/executor/src/main.rs +++ b/fhevm-engine/executor/src/main.rs @@ -1,6 +1,7 @@ use anyhow::Result; mod cli; +mod dfg; mod server; fn main() -> Result<()> { diff --git a/fhevm-engine/executor/src/server.rs b/fhevm-engine/executor/src/server.rs index 66b0c8af..629f1906 100644 --- a/fhevm-engine/executor/src/server.rs +++ b/fhevm-engine/executor/src/server.rs @@ -1,13 +1,14 @@ use std::{cell::Cell, collections::HashMap, sync::Arc}; use anyhow::Result; -use common::FheOperation; +pub use common::FheOperation; use executor::{ fhevm_executor_server::{FhevmExecutor, FhevmExecutorServer}, sync_compute_response::Resp, - sync_input::Input, - CompressedCiphertext, ResultCiphertexts, SyncComputation, SyncComputeError, SyncComputeRequest, - SyncComputeResponse, SyncInput, + ResultCiphertexts, SyncComputeResponse, SyncInput, +}; +pub use executor::{ + sync_input::Input, CompressedCiphertext, SyncComputation, SyncComputeError, SyncComputeRequest, }; use fhevm_engine_common::{ keys::{FhevmKeys, SerializedFhevmKeys}, @@ -19,6 +20,8 @@ use tfhe::{integer::U256, set_server_key}; use tokio::task::spawn_blocking; use tonic::{transport::Server, Code, Request, Response, Status}; +use crate::dfg::{scheduler::Scheduler, DFGraph}; + pub mod common { tonic::include_proto!("fhevm.common"); } @@ -28,13 +31,23 @@ pub mod executor { } pub fn start(args: &crate::cli::Args) -> Result<()> { + let keys: Arc = Arc::new(SerializedFhevmKeys::load_from_disk().into()); + let executor = FhevmExecutorService::new(keys.clone()); let runtime = tokio::runtime::Builder::new_multi_thread() .worker_threads(args.tokio_threads) .max_blocking_threads(args.fhe_compute_threads) + .on_thread_start(move || { + thread_local! { + static SERVER_KEY_IS_SET: Cell = const {Cell::new(false)}; + } + if !SERVER_KEY_IS_SET.get() { + set_server_key(keys.server_key.clone()); + SERVER_KEY_IS_SET.set(true); + } + }) .enable_all() .build()?; - let executor = FhevmExecutorService::new(); let addr = args.server_addr.parse().expect("server address"); runtime.block_on(async { @@ -47,14 +60,14 @@ pub fn start(args: &crate::cli::Args) -> Result<()> { Ok(()) } -struct InMemoryCiphertext { - expanded: SupportedFheCiphertexts, - compressed: Vec, +pub struct InMemoryCiphertext { + pub expanded: SupportedFheCiphertexts, + pub compressed: Vec, } #[derive(Default)] -struct ComputationState { - ciphertexts: HashMap, +pub struct ComputationState { + pub ciphertexts: HashMap, } struct FhevmExecutorService { @@ -69,15 +82,6 @@ impl FhevmExecutor for FhevmExecutorService { ) -> Result, Status> { let keys = self.keys.clone(); let resp = spawn_blocking(move || { - // Make sure we only clone the server key if needed. - thread_local! { - static SERVER_KEY_IS_SET: Cell = Cell::new(false); - } - if !SERVER_KEY_IS_SET.get() { - set_server_key(keys.server_key.clone()); - SERVER_KEY_IS_SET.set(true); - } - let req = req.get_ref(); let mut state = ComputationState::default(); @@ -98,25 +102,29 @@ impl FhevmExecutor for FhevmExecutorService { }; } - // Execute all computations. - let mut result_cts = Vec::new(); - for computation in &req.computations { - let outcome = Self::process_computation(computation, &mut state); - // Either all succeed or we return on the first failure. - match outcome { - Ok(cts) => result_cts.extend(cts), - Err(e) => { - return SyncComputeResponse { - resp: Some(Resp::Error(e.into())), - }; - } + // Run the request's computations in an async block + let handle = tokio::runtime::Handle::current(); + let _ = handle.enter(); + let resp = handle.block_on(async { + // Build the dataflow graph for this request + let mut graph = DFGraph::default(); + if let Err(e) = graph.build_from_request(req, &state) { + return Some(Resp::Error((e as SyncComputeError).into())); } - } - SyncComputeResponse { - resp: Some(Resp::ResultCiphertexts(ResultCiphertexts { - ciphertexts: result_cts, - })), - } + // Schedule computations in parallel as dependences allow + let mut sched = Scheduler::new(&mut graph.graph); + if sched.schedule().await.is_err() { + return Some(Resp::Error(SyncComputeError::ComputationFailed.into())); + } + // Extract the results from the graph + match graph.get_results() { + Ok(result_cts) => Some(Resp::ResultCiphertexts(ResultCiphertexts { + ciphertexts: result_cts, + })), + Err(e) => Some(Resp::Error(e.into())), + } + }); + SyncComputeResponse { resp } }) .await; match resp { @@ -130,10 +138,8 @@ impl FhevmExecutor for FhevmExecutorService { } impl FhevmExecutorService { - fn new() -> Self { - FhevmExecutorService { - keys: Arc::new(SerializedFhevmKeys::load_from_disk().into()), - } + fn new(keys: Arc) -> Self { + FhevmExecutorService { keys } } fn process_computation( @@ -278,3 +284,38 @@ impl FhevmExecutorService { } } } + +pub fn run_computation( + operation: i32, + inputs: Result, SyncComputeError>, + graph_node_index: usize, +) -> Result<(usize, InMemoryCiphertext), SyncComputeError> { + let op = FheOperation::try_from(operation); + match inputs { + Ok(inputs) => match op { + Ok(FheOperation::FheGetCiphertext) => { + let res = InMemoryCiphertext { + expanded: inputs[0].clone(), + compressed: inputs[0].clone().compress(), + }; + Ok((graph_node_index, res)) + } + Ok(_) => match perform_fhe_operation(operation as i16, &inputs) { + Ok(result) => { + let res = InMemoryCiphertext { + expanded: result.clone(), + compressed: result.compress(), + }; + Ok((graph_node_index, res)) + } + Err(_) => Err::<(usize, InMemoryCiphertext), SyncComputeError>( + SyncComputeError::ComputationFailed, + ), + }, + _ => Err::<(usize, InMemoryCiphertext), SyncComputeError>( + SyncComputeError::InvalidOperation, + ), + }, + Err(_) => Err(SyncComputeError::ComputationFailed), + } +} diff --git a/fhevm-engine/executor/tests/sync_compute.rs b/fhevm-engine/executor/tests/sync_compute.rs index d6355739..b52619f2 100644 --- a/fhevm-engine/executor/tests/sync_compute.rs +++ b/fhevm-engine/executor/tests/sync_compute.rs @@ -5,6 +5,7 @@ use executor::server::executor::{ fhevm_executor_client::FhevmExecutorClient, SyncComputation, SyncComputeRequest, }; use executor::server::executor::{sync_input::Input, SyncInput}; +use executor::server::SyncComputeError; use fhevm_engine_common::types::{SupportedFheCiphertexts, HANDLE_LEN}; use tfhe::CompactCiphertextListBuilder; use utils::get_test; @@ -244,3 +245,536 @@ async fn compute_on_result_ciphertext() { Resp::Error(e) => assert!(false, "error response: {}", e), } } + +#[tokio::test] +async fn schedule_dependent_computations() { + let test = get_test().await; + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) + .await + .unwrap(); + let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key); + let list = builder + .push(3_u16) + .push(5_u16) + .push(7_u16) + .push(11_u16) + .push(13_u16) + .build(); + let expander = list.expand_with_key(&test.keys.server_key).unwrap(); + let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap()); + let ct1 = test.compress(ct1); + let ct2 = SupportedFheCiphertexts::FheUint16(expander.get(1).unwrap().unwrap()); + let ct2 = test.compress(ct2); + let ct3 = SupportedFheCiphertexts::FheUint16(expander.get(2).unwrap().unwrap()); + let ct3 = test.compress(ct3); + let ct4 = SupportedFheCiphertexts::FheUint16(expander.get(3).unwrap().unwrap()); + let ct4 = test.compress(ct4); + let ct5 = SupportedFheCiphertexts::FheUint16(expander.get(4).unwrap().unwrap()); + let ct5 = test.compress(ct5); + let handle1 = test.ciphertext_handle(&ct1, 3); + let sync_input1 = SyncInput { + input: Some(Input::Handle(handle1.clone())), + }; + let handle2 = test.ciphertext_handle(&ct2, 3); + let sync_input2 = SyncInput { + input: Some(Input::Handle(handle2.clone())), + }; + let handle3 = test.ciphertext_handle(&ct3, 3); + let sync_input3 = SyncInput { + input: Some(Input::Handle(handle3.clone())), + }; + let handle4 = test.ciphertext_handle(&ct4, 3); + let sync_input4 = SyncInput { + input: Some(Input::Handle(handle4.clone())), + }; + let handle5 = test.ciphertext_handle(&ct5, 3); + let sync_input5 = SyncInput { + input: Some(Input::Handle(handle5.clone())), + }; + let sync_input6 = SyncInput { + input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), + }; + let sync_input7 = SyncInput { + input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), + }; + let sync_input8 = SyncInput { + input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), + }; + + let computation1 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xaa; HANDLE_LEN]], + inputs: vec![sync_input1, sync_input2], + }; + let computation2 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xbb; HANDLE_LEN]], + inputs: vec![sync_input3, sync_input4], + }; + let computation3 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xcc; HANDLE_LEN]], + inputs: vec![sync_input6, sync_input7], + }; + let computation4 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xdd; HANDLE_LEN]], + inputs: vec![sync_input5, sync_input8], + }; + + let req = SyncComputeRequest { + computations: vec![computation4, computation3, computation2, computation1], + compact_ciphertext_lists: vec![], + compressed_ciphertexts: vec![ + CompressedCiphertext { + handle: handle1, + serialization: ct1, + }, + CompressedCiphertext { + handle: handle2, + serialization: ct2, + }, + CompressedCiphertext { + handle: handle3, + serialization: ct3, + }, + CompressedCiphertext { + handle: handle4, + serialization: ct4, + }, + CompressedCiphertext { + handle: handle5, + serialization: ct5, + }, + ], + }; + let response = client.sync_compute(req).await.unwrap(); + let sync_compute_response = response.get_ref(); + let resp = sync_compute_response.resp.clone().unwrap(); + match resp { + Resp::ResultCiphertexts(cts) => { + assert!( + cts.ciphertexts.len() == 4, + "wrong number of output ciphertexts {} instead of {}", + cts.ciphertexts.len(), + 4 + ); + let aa: Vec = vec![0xaa; HANDLE_LEN]; + let bb: Vec = vec![0xbb; HANDLE_LEN]; + let cc: Vec = vec![0xcc; HANDLE_LEN]; + let dd: Vec = vec![0xdd; HANDLE_LEN]; + for ct in cts.ciphertexts.iter() { + match &ct.handle { + a if *a == aa => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "8" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + b if *b == bb => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "18" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + c if *c == cc => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "26" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + d if *d == dd => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "39" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), + } + } + } + Resp::Error(e) => assert!(false, "error response: {}", e), + } +} + +#[tokio::test] +async fn schedule_circular_dependence() { + let test = get_test().await; + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) + .await + .unwrap(); + let sync_input1 = SyncInput { + input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), + }; + let sync_input2 = SyncInput { + input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), + }; + let sync_input3 = SyncInput { + input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), + }; + + let computation1 = SyncComputation { + operation: FheOperation::FheNeg.into(), + result_handles: vec![vec![0xbb; HANDLE_LEN]], + inputs: vec![sync_input1], + }; + let computation2 = SyncComputation { + operation: FheOperation::FheNeg.into(), + result_handles: vec![vec![0xcc; HANDLE_LEN]], + inputs: vec![sync_input2], + }; + let computation3 = SyncComputation { + operation: FheOperation::FheNeg.into(), + result_handles: vec![vec![0xaa; HANDLE_LEN]], + inputs: vec![sync_input3], + }; + + let req = SyncComputeRequest { + computations: vec![computation1, computation2, computation3], + compact_ciphertext_lists: vec![], + compressed_ciphertexts: vec![], + }; + let response = client.sync_compute(req).await.unwrap(); + let sync_compute_response = response.get_ref(); + let resp = sync_compute_response.resp.clone().unwrap(); + match resp { + Resp::ResultCiphertexts(_cts) => assert!( + false, + "Received ciphertext outputs despite circular dependence." + ), + Resp::Error(e) => assert!( + e == SyncComputeError::UnsatisfiedDependence as i32, + "Error response should be UnsatisfiedDependence but is {}", + e + ), + } +} + +#[tokio::test] +async fn schedule_diamond_reduction_dependence_pattern() { + let test = get_test().await; + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) + .await + .unwrap(); + let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key); + let list = builder + .push(1_u16) + .push(2_u16) + .push(3_u16) + .push(4_u16) + .push(5_u16) + .build(); + let expander = list.expand_with_key(&test.keys.server_key).unwrap(); + let ct1 = SupportedFheCiphertexts::FheUint16(expander.get(0).unwrap().unwrap()); + let ct1 = test.compress(ct1); + let ct2 = SupportedFheCiphertexts::FheUint16(expander.get(1).unwrap().unwrap()); + let ct2 = test.compress(ct2); + let ct3 = SupportedFheCiphertexts::FheUint16(expander.get(2).unwrap().unwrap()); + let ct3 = test.compress(ct3); + let ct4 = SupportedFheCiphertexts::FheUint16(expander.get(3).unwrap().unwrap()); + let ct4 = test.compress(ct4); + let ct5 = SupportedFheCiphertexts::FheUint16(expander.get(4).unwrap().unwrap()); + let ct5 = test.compress(ct5); + let handle1 = test.ciphertext_handle(&ct1, 3); + let sync_input1 = SyncInput { + input: Some(Input::Handle(handle1.clone())), + }; + let handle2 = test.ciphertext_handle(&ct2, 3); + let sync_input2 = SyncInput { + input: Some(Input::Handle(handle2.clone())), + }; + let handle3 = test.ciphertext_handle(&ct3, 3); + let sync_input3 = SyncInput { + input: Some(Input::Handle(handle3.clone())), + }; + let handle4 = test.ciphertext_handle(&ct4, 3); + let sync_input4 = SyncInput { + input: Some(Input::Handle(handle4.clone())), + }; + let handle5 = test.ciphertext_handle(&ct5, 3); + let sync_input5 = SyncInput { + input: Some(Input::Handle(handle5.clone())), + }; + let sync_input_aa = SyncInput { + input: Some(Input::Handle(vec![0xaa; HANDLE_LEN])), + }; + let sync_input_bb = SyncInput { + input: Some(Input::Handle(vec![0xbb; HANDLE_LEN])), + }; + let sync_input_cc = SyncInput { + input: Some(Input::Handle(vec![0xcc; HANDLE_LEN])), + }; + let sync_input_dd = SyncInput { + input: Some(Input::Handle(vec![0xdd; HANDLE_LEN])), + }; + let sync_input_ee = SyncInput { + input: Some(Input::Handle(vec![0xee; HANDLE_LEN])), + }; + let sync_input_ff = SyncInput { + input: Some(Input::Handle(vec![0xff; HANDLE_LEN])), + }; + let sync_input_99 = SyncInput { + input: Some(Input::Handle(vec![0x99; HANDLE_LEN])), + }; + + let computation1 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xaa; HANDLE_LEN]], + inputs: vec![sync_input1.clone(), sync_input1], + }; // Compute 1 + 1 + let computation2 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xbb; HANDLE_LEN]], + inputs: vec![sync_input2, sync_input_aa.clone()], + }; // 2 + 2 + let computation3 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xcc; HANDLE_LEN]], + inputs: vec![sync_input3, sync_input_aa.clone()], + }; // 2 + 3 + let computation4 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xdd; HANDLE_LEN]], + inputs: vec![sync_input4, sync_input_aa.clone()], + }; // 2 + 4 + let computation5 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xee; HANDLE_LEN]], + inputs: vec![sync_input5, sync_input_aa.clone()], + }; // 2 + 5 + + let computation6 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xff; HANDLE_LEN]], + inputs: vec![sync_input_bb, sync_input_cc], + }; // 4 + 5 + let computation7 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0x99; HANDLE_LEN]], + inputs: vec![sync_input_dd, sync_input_ee], + }; // 6 + 7 + let computation8 = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0x88; HANDLE_LEN]], + inputs: vec![sync_input_ff, sync_input_99], + }; // 9 + 13 + + let req = SyncComputeRequest { + computations: vec![ + computation4, + computation3, + computation2, + computation1, + computation5, + computation6, + computation7, + computation8, + ], + compact_ciphertext_lists: vec![], + compressed_ciphertexts: vec![ + CompressedCiphertext { + handle: handle1, + serialization: ct1, + }, + CompressedCiphertext { + handle: handle2, + serialization: ct2, + }, + CompressedCiphertext { + handle: handle3, + serialization: ct3, + }, + CompressedCiphertext { + handle: handle4, + serialization: ct4, + }, + CompressedCiphertext { + handle: handle5, + serialization: ct5, + }, + ], + }; + let response = client.sync_compute(req).await.unwrap(); + let sync_compute_response = response.get_ref(); + let resp = sync_compute_response.resp.clone().unwrap(); + match resp { + Resp::ResultCiphertexts(cts) => { + assert!( + cts.ciphertexts.len() == 8, + "wrong number of output ciphertexts {} instead of {}", + cts.ciphertexts.len(), + 8 + ); + let aa: Vec = vec![0xaa; HANDLE_LEN]; + let bb: Vec = vec![0xbb; HANDLE_LEN]; + let cc: Vec = vec![0xcc; HANDLE_LEN]; + let dd: Vec = vec![0xdd; HANDLE_LEN]; + let ee: Vec = vec![0xee; HANDLE_LEN]; + let ff: Vec = vec![0xff; HANDLE_LEN]; + let x88: Vec = vec![0x88; HANDLE_LEN]; + let x99: Vec = vec![0x99; HANDLE_LEN]; + for ct in cts.ciphertexts.iter() { + match &ct.handle { + a if *a == aa => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "2" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + b if *b == bb => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "4" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + c if *c == cc => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "5" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + d if *d == dd => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "6" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + e if *e == ee => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "7" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + f if *f == ff => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "9" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + x if *x == x99 => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "13" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + x if *x == x88 => { + let ctd = + SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ctd + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "22" => (), + s => assert!( + false, + "unexpected result: {} for handle 0x{:x}", + s, ct.handle[0] + ), + } + } + _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), + } + } + } + Resp::Error(e) => assert!(false, "error response: {}", e), + } +} diff --git a/proto/executor.proto b/proto/executor.proto index 76f31bce..54a89067 100644 --- a/proto/executor.proto +++ b/proto/executor.proto @@ -58,6 +58,7 @@ enum SyncComputeError { UNKNOWN_HANDLE = 5; COMPUTATION_FAILED = 6; BAD_RESULT_HANDLES = 7; + UNSATISFIED_DEPENDENCE = 8; } message CompressedCiphertext {