-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(executor): add a parallel dataflow scheduler for computations in…
… a SyncComputeRequest The scheduler runs on top of a DFG built using the daggy extension to petgraph and executes computations/tasks along the wavefront of maximum parallelism within the graph. Circular dependences are detected early.
- Loading branch information
1 parent
9ac06fd
commit 5080aeb
Showing
11 changed files
with
874 additions
and
41 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
pub mod scheduler; | ||
mod types; | ||
|
||
use crate::dfg::types::*; | ||
use crate::server::{ | ||
CompressedCiphertext, ComputationState, Input, SyncComputation, SyncComputeError, | ||
SyncComputeRequest, | ||
}; | ||
use anyhow::Result; | ||
use fhevm_engine_common::types::{ | ||
FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN, | ||
}; | ||
use tfhe::integer::U256; | ||
|
||
use daggy::{Dag, NodeIndex}; | ||
use std::collections::HashMap; | ||
|
||
//TODO#[derive(Debug)] | ||
pub struct Node<'a> { | ||
computation: &'a SyncComputation, | ||
result: DFGTaskResult, | ||
result_handle: Handle, | ||
inputs: Vec<DFGTaskInput>, | ||
} | ||
pub type Edge = u8; | ||
|
||
//TODO#[derive(Debug)] | ||
#[derive(Default)] | ||
pub struct DFGraph<'a> { | ||
pub graph: Dag<Node<'a>, Edge>, | ||
produced_handles: HashMap<&'a Handle, NodeIndex>, | ||
} | ||
|
||
impl<'a> DFGraph<'a> { | ||
pub fn add_node( | ||
&mut self, | ||
computation: &'a SyncComputation, | ||
inputs: Vec<DFGTaskInput>, | ||
) -> Result<NodeIndex, SyncComputeError> { | ||
let rh = computation | ||
.result_handles | ||
.first() | ||
.filter(|h| h.len() == HANDLE_LEN) | ||
.ok_or(SyncComputeError::BadResultHandles)?; | ||
Ok(self.graph.add_node(Node { | ||
computation, | ||
result: None, | ||
result_handle: rh.clone(), | ||
inputs, | ||
})) | ||
} | ||
|
||
pub fn add_dependence( | ||
&mut self, | ||
source: NodeIndex, | ||
destination: NodeIndex, | ||
consumer_input: Edge, | ||
) -> Result<(), SyncComputeError> { | ||
let _edge = self | ||
.graph | ||
.add_edge(source, destination, consumer_input) | ||
.map_err(|_| SyncComputeError::UnsatisfiedDependence)?; | ||
Ok(()) | ||
} | ||
pub fn build_from_request( | ||
&mut self, | ||
req: &'a SyncComputeRequest, | ||
state: &ComputationState, | ||
) -> Result<(), SyncComputeError> { | ||
// Add all computations as nodes in the graph. | ||
for computation in &req.computations { | ||
let inputs: Result<Vec<DFGTaskInput>> = computation | ||
.inputs | ||
.iter() | ||
.map(|input| match &input.input { | ||
Some(input) => match input { | ||
Input::Handle(h) => { | ||
if let Some(ct) = state.ciphertexts.get(h) { | ||
Ok(DFGTaskInput::Val(ct.expanded.clone())) | ||
} else { | ||
Ok(DFGTaskInput::Handle(h.clone())) | ||
} | ||
} | ||
Input::Scalar(s) if s.len() == SCALAR_LEN => { | ||
let mut scalar = U256::default(); | ||
scalar.copy_from_be_byte_slice(s); | ||
Ok(DFGTaskInput::Val(SupportedFheCiphertexts::Scalar(scalar))) | ||
} | ||
_ => Err(FhevmError::BadInputs.into()), | ||
}, | ||
None => Err(FhevmError::BadInputs.into()), | ||
}) | ||
.collect(); | ||
if let Ok(mut inputs) = inputs { | ||
let n = self.add_node(computation, std::mem::take(&mut inputs))?; | ||
self.produced_handles.insert( | ||
computation | ||
.result_handles | ||
.first() | ||
.filter(|h| h.len() == HANDLE_LEN) | ||
.ok_or(SyncComputeError::BadResultHandles)?, | ||
n, | ||
); | ||
} | ||
} | ||
// Traverse nodes and add dependences/edges as required | ||
for index in 0..self.graph.node_count() { | ||
let take_inputs = std::mem::take( | ||
&mut self | ||
.graph | ||
.node_weight_mut(NodeIndex::new(index)) | ||
.unwrap() | ||
.inputs, | ||
); | ||
for (idx, input) in take_inputs.iter().enumerate() { | ||
match input { | ||
DFGTaskInput::Handle(input) => { | ||
if let Some(producer_index) = self.produced_handles.get(input) { | ||
self.add_dependence(*producer_index, NodeIndex::new(index), idx as u8)?; | ||
} | ||
} | ||
DFGTaskInput::Val(_) => {} | ||
}; | ||
} | ||
self.graph | ||
.node_weight_mut(NodeIndex::new(index)) | ||
.unwrap() | ||
.inputs = take_inputs; | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
pub fn get_results(&mut self) -> Result<Vec<CompressedCiphertext>, SyncComputeError> { | ||
let mut res = Vec::with_capacity(self.graph.node_count()); | ||
for index in 0..self.graph.node_count() { | ||
let node = self.graph.node_weight_mut(NodeIndex::new(index)).unwrap(); | ||
if let Some(imc) = &node.result { | ||
res.push(CompressedCiphertext { | ||
handle: node.result_handle.clone(), | ||
serialization: imc.compressed.clone(), | ||
}); | ||
} else { | ||
return Err(SyncComputeError::ComputationFailed); | ||
} | ||
} | ||
Ok(res) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
use crate::dfg::{types::DFGTaskInput, Edge, Node}; | ||
use crate::server::{run_computation, InMemoryCiphertext, SyncComputeError}; | ||
use anyhow::Result; | ||
use fhevm_engine_common::types::SupportedFheCiphertexts; | ||
|
||
use daggy::{ | ||
petgraph::{ | ||
visit::{EdgeRef, IntoEdgesDirected}, | ||
Direction, | ||
}, | ||
Dag, NodeIndex, | ||
}; | ||
use tokio::task::JoinSet; | ||
|
||
pub struct Scheduler<'a, 'b> { | ||
graph: &'b mut Dag<Node<'a>, Edge>, | ||
edges: Dag<(), Edge>, | ||
set: JoinSet<Result<(usize, InMemoryCiphertext), SyncComputeError>>, | ||
} | ||
|
||
impl<'a, 'b> Scheduler<'a, 'b> { | ||
fn is_ready(node: &Node<'a>) -> bool { | ||
let mut ready = true; | ||
for i in node.inputs.iter() { | ||
if let DFGTaskInput::Handle(_) = i { | ||
ready = false; | ||
} | ||
} | ||
ready | ||
} | ||
pub fn new(graph: &'b mut Dag<Node<'a>, Edge>) -> Self { | ||
let mut set = JoinSet::new(); | ||
for idx in 0..graph.node_count() { | ||
let index = NodeIndex::new(idx); | ||
let node = graph.node_weight_mut(index).unwrap(); | ||
if Self::is_ready(node) { | ||
let opc = node.computation.operation; | ||
let inputs: Result<Vec<SupportedFheCiphertexts>, SyncComputeError> = node | ||
.inputs | ||
.iter() | ||
.map(|i| match i { | ||
DFGTaskInput::Val(i) => Ok(i.clone()), | ||
DFGTaskInput::Handle(_) => Err(SyncComputeError::ComputationFailed), | ||
}) | ||
.collect(); | ||
set.spawn_blocking(move || run_computation(opc, inputs, idx)); | ||
} | ||
} | ||
|
||
let edges = graph.map(|_, _| (), |_, edge| *edge); | ||
|
||
Self { graph, edges, set } | ||
} | ||
pub async fn schedule(&mut self) -> Result<(), SyncComputeError> { | ||
while let Some(result) = self.set.join_next().await { | ||
let output = result.map_err(|_| SyncComputeError::ComputationFailed)??; | ||
let index = output.0; | ||
let node_index = NodeIndex::new(index); | ||
// Satisfy deps from the executed task | ||
for edge in self.edges.edges_directed(node_index, Direction::Outgoing) { | ||
let child_index = edge.target(); | ||
let child_node = self.graph.node_weight_mut(child_index).unwrap(); | ||
child_node.inputs[*edge.weight() as usize] = | ||
DFGTaskInput::Val(output.1.expanded.clone()); | ||
if Self::is_ready(child_node) { | ||
let opc = child_node.computation.operation; | ||
let inputs: Result<Vec<SupportedFheCiphertexts>, SyncComputeError> = child_node | ||
.inputs | ||
.iter() | ||
.map(|i| match i { | ||
DFGTaskInput::Val(i) => Ok(i.clone()), | ||
DFGTaskInput::Handle(_) => Err(SyncComputeError::ComputationFailed), | ||
}) | ||
.collect(); | ||
self.set | ||
.spawn_blocking(move || run_computation(opc, inputs, child_index.index())); | ||
} | ||
} | ||
self.graph.node_weight_mut(node_index).unwrap().result = Some(output.1); | ||
} | ||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
use fhevm_engine_common::types::{Handle, SupportedFheCiphertexts}; | ||
|
||
use crate::server::InMemoryCiphertext; | ||
|
||
pub type DFGTaskResult = Option<InMemoryCiphertext>; | ||
|
||
#[derive(Clone)] | ||
pub enum DFGTaskInput { | ||
Val(SupportedFheCiphertexts), | ||
Handle(Handle), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
pub mod cli; | ||
pub mod dfg; | ||
pub mod server; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
use anyhow::Result; | ||
|
||
mod cli; | ||
mod dfg; | ||
mod server; | ||
|
||
fn main() -> Result<()> { | ||
|
Oops, something went wrong.