Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add FHE computation to the executor #14

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions fhevm-engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions fhevm-engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ prost = "0.13"
tonic = { version = "0.12", features = ["server"] }
bincode = "1.3.3"
sha3 = "0.10.8"
anyhow = "1.0.86"

1 change: 1 addition & 0 deletions fhevm-engine/executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tonic.workspace = true
tfhe.workspace = true
bincode.workspace = true
sha3.workspace = true
anyhow.workspace = true
fhevm-engine-common = { path = "../fhevm-engine-common" }

[build-dependencies]
Expand Down
137 changes: 96 additions & 41 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use executor::{
};
use fhevm_engine_common::{
keys::{FhevmKeys, SerializedFhevmKeys},
tfhe_ops::{current_ciphertext_version, try_expand_ciphertext_list},
types::{FhevmError, Handle, SupportedFheCiphertexts},
tfhe_ops::{current_ciphertext_version, perform_fhe_operation, try_expand_ciphertext_list},
types::{FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN},
};
use sha3::{Digest, Keccak256};
use tfhe::set_server_key;
use tfhe::{integer::U256, set_server_key};
use tokio::task::spawn_blocking;
use tonic::{transport::Server, Code, Request, Response, Status};

Expand Down Expand Up @@ -56,20 +56,6 @@ struct ComputationState {
ciphertexts: HashMap<Handle, InMemoryCiphertext>,
}

fn error_response(error: SyncComputeError) -> SyncComputeResponse {
SyncComputeResponse {
resp: Some(Resp::Error(error.into())),
}
}

fn success_response(cts: Vec<Ciphertext>) -> SyncComputeResponse {
SyncComputeResponse {
resp: Some(Resp::ResultCiphertexts(ResultCiphertexts {
ciphertexts: cts,
})),
}
}

struct FhevmExecutorService {
keys: Arc<FhevmKeys>,
}
Expand All @@ -95,24 +81,30 @@ impl FhevmExecutor for FhevmExecutorService {
let req = req.get_ref();
let mut state = ComputationState::default();
if Self::expand_inputs(&req.input_lists, &keys, &mut state).is_err() {
return error_response(SyncComputeError::BadInputList);
return SyncComputeResponse {
resp: Some(Resp::Error(SyncComputeError::BadInputList.into())),
};
}

// Execute all computations.
let mut result_cts = Vec::new();
for computation in &req.computations {
let outcome = Self::process_computation(computation, &mut state);
// Either all succeed or we return on the first failure.
match outcome.resp.unwrap() {
Resp::Error(error) => {
return error_response(
SyncComputeError::try_from(error).expect("correct error value"),
);
match outcome {
Ok(cts) => result_cts.extend(cts),
Err(e) => {
return SyncComputeResponse {
resp: Some(Resp::Error(e.into())),
};
}
Resp::ResultCiphertexts(cts) => result_cts.extend(cts.ciphertexts),
}
}
success_response(result_cts)
SyncComputeResponse {
resp: Some(Resp::ResultCiphertexts(ResultCiphertexts {
ciphertexts: result_cts,
})),
}
})
.await;
match resp {
Expand All @@ -135,12 +127,21 @@ impl FhevmExecutorService {
fn process_computation(
comp: &SyncComputation,
state: &mut ComputationState,
) -> SyncComputeResponse {
) -> Result<Vec<Ciphertext>, SyncComputeError> {
// For now, assume only one result handle.
let result_handle = comp
.result_handles
.first()
.filter(|h| h.len() == HANDLE_LEN)
.ok_or_else(|| SyncComputeError::BadResultHandles)?
.clone();
let op = FheOperation::try_from(comp.operation);
match op {
Ok(FheOperation::FheGetInputCiphertext) => Self::get_input_ciphertext(comp, &state),
Ok(_) => error_response(SyncComputeError::UnsupportedOperation),
_ => error_response(SyncComputeError::InvalidOperation),
Ok(FheOperation::FheGetCiphertext) => {
Self::get_ciphertext(comp, &result_handle, &state)
}
Ok(_) => Self::compute(comp, result_handle, state),
_ => Err(SyncComputeError::InvalidOperation),
}
}

Expand All @@ -151,9 +152,9 @@ impl FhevmExecutorService {
) -> Result<(), FhevmError> {
for list in lists {
let cts = try_expand_ciphertext_list(&list, &keys.server_key)?;
let list_hash: Handle = Keccak256::digest(list).into();
let list_hash: Handle = Keccak256::digest(list).to_vec();
for (i, ct) in cts.iter().enumerate() {
let mut handle = list_hash;
let mut handle = list_hash.clone();
handle[29] = i as u8;
handle[30] = ct.type_num() as u8;
handle[31] = current_ciphertext_version() as u8;
Expand All @@ -169,31 +170,85 @@ impl FhevmExecutorService {
Ok(())
}

fn get_input_ciphertext(
fn get_ciphertext(
comp: &SyncComputation,
result_handle: &Handle,
state: &ComputationState,
) -> SyncComputeResponse {
) -> Result<Vec<Ciphertext>, SyncComputeError> {
match (comp.inputs.first(), comp.inputs.len()) {
(
Some(SyncInput {
input: Some(Input::InputHandle(handle)),
}),
1,
) => {
if let Ok(handle) = (handle as &[u8]).try_into() as Result<Handle, _> {
if let Some(in_mem_ciphertext) = state.ciphertexts.get(&handle) {
success_response(vec![Ciphertext {
handle: handle.to_vec(),
if let Some(in_mem_ciphertext) = state.ciphertexts.get(handle) {
if *handle != *result_handle {
Err(SyncComputeError::BadInputs)
} else {
Ok(vec![Ciphertext {
handle: result_handle.to_vec(),
ciphertext: in_mem_ciphertext.compressed.clone(),
}])
} else {
error_response(SyncComputeError::UnknownHandle)
}
} else {
error_response(SyncComputeError::BadInputs)
Err(SyncComputeError::UnknownHandle)
}
}
_ => error_response(SyncComputeError::BadInputs),
_ => Err(SyncComputeError::BadInputs),
}
}

fn compute(
comp: &SyncComputation,
result_handle: Handle,
state: &mut ComputationState,
) -> Result<Vec<Ciphertext>, SyncComputeError> {
// Collect computation inputs.
let inputs: Result<Vec<SupportedFheCiphertexts>, Box<dyn Error>> = comp
.inputs
.iter()
.map(|sync_input| match &sync_input.input {
Some(input) => match input {
Input::Ciphertext(c) if c.handle.len() == HANDLE_LEN => {
let ct_type = c.handle[30] as i16;
Ok(SupportedFheCiphertexts::decompress(ct_type, &c.ciphertext)?)
}
Input::InputHandle(h) => {
let ct = state.ciphertexts.get(h).ok_or(FhevmError::BadInputs)?;
Ok(ct.expanded.clone())
}
Input::Scalar(s) if s.len() == SCALAR_LEN => {
let mut scalar = U256::default();
scalar.copy_from_be_byte_slice(&s);
Ok(SupportedFheCiphertexts::Scalar(scalar))
}
_ => Err(FhevmError::BadInputs.into()),
},
None => Err(FhevmError::BadInputs.into()),
})
.collect();

// Do the computation on the inputs.
match inputs {
Ok(inputs) => match perform_fhe_operation(comp.operation as i16, &inputs) {
Ok(result) => {
let compressed = result.clone().compress();
state.ciphertexts.insert(
result_handle.clone(),
InMemoryCiphertext {
expanded: result,
compressed: compressed.clone(),
},
);
Ok(vec![Ciphertext {
handle: result_handle,
ciphertext: compressed,
}])
}
Err(_) => Err(SyncComputeError::ComputationFailed),
},
Err(_) => Err(SyncComputeError::BadInputs),
}
}
}
93 changes: 77 additions & 16 deletions fhevm-engine/executor/tests/sync_compute.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
use anyhow::{anyhow, Result};
use executor::server::common::FheOperation;
use executor::server::executor::sync_compute_response::Resp;
use executor::server::executor::Ciphertext;
use executor::server::executor::{
fhevm_executor_client::FhevmExecutorClient, SyncComputation, SyncComputeRequest,
};
use executor::server::executor::{sync_input::Input, SyncInput};
use fhevm_engine_common::types::{SupportedFheCiphertexts, HANDLE_LEN};
use tfhe::CompactCiphertextListBuilder;
use utils::get_test;

mod utils;

#[tokio::test]
async fn get_input_ciphertexts() -> Result<(), Box<dyn std::error::Error>> {
async fn get_input_ciphertext() -> Result<()> {
let test = get_test().await;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone()).await?;
let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let list = bincode::serialize(&builder.push(10_u8).build()).unwrap();
let list = bincode::serialize(&builder.push(10_u8).build())?;
// TODO: tests for all types and avoiding passing in 2 as an identifier for FheUint8.
let input_handle = test.input_handle(&list, 0, 2);
let sync_input = SyncInput {
input: Some(Input::InputHandle(input_handle.to_vec())),
input: Some(Input::InputHandle(input_handle.clone())),
};
let computation = SyncComputation {
operation: FheOperation::FheGetInputCiphertext.into(),
result_handles: vec![vec![0xaa]],
operation: FheOperation::FheGetCiphertext.into(),
result_handles: vec![input_handle.clone()],
inputs: vec![sync_input],
};
let req = SyncComputeRequest {
Expand All @@ -31,18 +34,76 @@ async fn get_input_ciphertexts() -> Result<(), Box<dyn std::error::Error>> {
};
let response = client.sync_compute(req).await?;
let sync_compute_response = response.get_ref();
match &sync_compute_response.resp {
Some(Resp::ResultCiphertexts(cts)) => {
match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != input_handle || ct.ciphertext.is_empty() {
assert!(false);
}
let resp = <Option<Resp> as Clone>::clone(&sync_compute_response.resp)
.ok_or_else(|| anyhow!("resp is None"))?;
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != input_handle || ct.ciphertext.is_empty() {
return Err(anyhow!("response handle or ciphertext are unexpected"));
}
_ => assert!(false),
Ok(())
}
}
_ => assert!(false),
_ => Err(anyhow!("unexpected amount of result ciphertexts returned")),
},
Resp::Error(e) => Err(anyhow!(format!("error response: {}", e))),
}
}

#[tokio::test]
async fn fhe_compute_two_ciphertexts() -> Result<()> {
let test = get_test().await;
let mut client = FhevmExecutorClient::connect(test.server_addr.clone()).await?;
let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key);
let list = builder.push(10_u16).push(11_u16).build();
let expander = list.expand_with_key(&test.keys.server_key)?;
let ct1 = SupportedFheCiphertexts::FheUint16(
expander
.get(0)
.ok_or(anyhow!("missing ciphertext at index 0"))??,
);
let ct1 = test.compress(ct1);
let ct2 = SupportedFheCiphertexts::FheUint16(
expander
.get(1)
.ok_or(anyhow!("missing ciphertext at index 1"))??,
);
let ct2 = test.compress(ct2);
let sync_input1 = SyncInput {
input: Some(Input::Ciphertext(Ciphertext {
handle: test.ciphertext_handle(&ct1, 3).to_vec(),
ciphertext: ct1,
})),
};
let sync_input2 = SyncInput {
input: Some(Input::Ciphertext(Ciphertext {
handle: test.ciphertext_handle(&ct2, 3).to_vec(),
ciphertext: ct2,
})),
};
let computation = SyncComputation {
operation: FheOperation::FheAdd.into(),
result_handles: vec![vec![0xaa; HANDLE_LEN]],
inputs: vec![sync_input1, sync_input2],
};
let req = SyncComputeRequest {
computations: vec![computation],
input_lists: vec![],
};
let response = client.sync_compute(req).await?;
let sync_compute_response = response.get_ref();
let resp = <Option<Resp> as Clone>::clone(&sync_compute_response.resp)
.ok_or_else(|| anyhow!("resp is None"))?;
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != vec![0xaa; HANDLE_LEN] || ct.ciphertext.is_empty() {
return Err(anyhow!("response handle or ciphertext are unexpected"));
}
Ok(())
}
_ => Err(anyhow!("unexpected amount of result ciphertexts returned")),
},
Resp::Error(e) => Err(anyhow!(format!("error response: {}", e))),
}
Ok(())
}
Loading