From a1913afa9f732c4a9bf1fab5f3c17a875e1abde3 Mon Sep 17 00:00:00 2001 From: Petar Ivanov <29689712+dartdart26@users.noreply.github.com> Date: Thu, 29 Aug 2024 15:42:51 +0300 Subject: [PATCH] feat: add FHE computation to the executor Error handling is still rough, can be simplified further. --- fhevm-engine/Cargo.lock | 3 + fhevm-engine/Cargo.toml | 2 + fhevm-engine/executor/Cargo.toml | 1 + fhevm-engine/executor/src/server.rs | 137 ++++++++++++------ fhevm-engine/executor/tests/sync_compute.rs | 93 ++++++++++-- fhevm-engine/executor/tests/utils.rs | 17 ++- fhevm-engine/fhevm-engine-common/Cargo.toml | 8 +- fhevm-engine/fhevm-engine-common/src/types.rs | 44 +++++- proto/common.proto | 2 +- proto/executor.proto | 2 + 10 files changed, 246 insertions(+), 63 deletions(-) diff --git a/fhevm-engine/Cargo.lock b/fhevm-engine/Cargo.lock index 4b56ea2e..6f8f887e 100644 --- a/fhevm-engine/Cargo.lock +++ b/fhevm-engine/Cargo.lock @@ -862,6 +862,7 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" name = "executor" version = "0.1.0" dependencies = [ + "anyhow", "bincode", "clap", "fhevm-engine-common", @@ -883,8 +884,10 @@ checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" name = "fhevm-engine-common" version = "0.1.0" dependencies = [ + "anyhow", "bincode", "hex", + "sha3", "strum", "tfhe", ] diff --git a/fhevm-engine/Cargo.toml b/fhevm-engine/Cargo.toml index 8898bffd..e71f4d7c 100644 --- a/fhevm-engine/Cargo.toml +++ b/fhevm-engine/Cargo.toml @@ -10,3 +10,5 @@ prost = "0.13" tonic = { version = "0.12", features = ["server"] } bincode = "1.3.3" sha3 = "0.10.8" +anyhow = "1.0.86" + diff --git a/fhevm-engine/executor/Cargo.toml b/fhevm-engine/executor/Cargo.toml index 2b28b30d..e33cda0a 100644 --- a/fhevm-engine/executor/Cargo.toml +++ b/fhevm-engine/executor/Cargo.toml @@ -11,6 +11,7 @@ tonic.workspace = true tfhe.workspace = true bincode.workspace = true sha3.workspace = true +anyhow.workspace = true fhevm-engine-common = { path = "../fhevm-engine-common" } [build-dependencies] diff --git a/fhevm-engine/executor/src/server.rs b/fhevm-engine/executor/src/server.rs index 88609676..9959ddfd 100644 --- a/fhevm-engine/executor/src/server.rs +++ b/fhevm-engine/executor/src/server.rs @@ -10,11 +10,11 @@ use executor::{ }; use fhevm_engine_common::{ keys::{FhevmKeys, SerializedFhevmKeys}, - tfhe_ops::{current_ciphertext_version, try_expand_ciphertext_list}, - types::{FhevmError, Handle, SupportedFheCiphertexts}, + tfhe_ops::{current_ciphertext_version, perform_fhe_operation, try_expand_ciphertext_list}, + types::{FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN}, }; use sha3::{Digest, Keccak256}; -use tfhe::set_server_key; +use tfhe::{integer::U256, set_server_key}; use tokio::task::spawn_blocking; use tonic::{transport::Server, Code, Request, Response, Status}; @@ -56,20 +56,6 @@ struct ComputationState { ciphertexts: HashMap, } -fn error_response(error: SyncComputeError) -> SyncComputeResponse { - SyncComputeResponse { - resp: Some(Resp::Error(error.into())), - } -} - -fn success_response(cts: Vec) -> SyncComputeResponse { - SyncComputeResponse { - resp: Some(Resp::ResultCiphertexts(ResultCiphertexts { - ciphertexts: cts, - })), - } -} - struct FhevmExecutorService { keys: Arc, } @@ -95,7 +81,9 @@ impl FhevmExecutor for FhevmExecutorService { let req = req.get_ref(); let mut state = ComputationState::default(); if Self::expand_inputs(&req.input_lists, &keys, &mut state).is_err() { - return error_response(SyncComputeError::BadInputList); + return SyncComputeResponse { + resp: Some(Resp::Error(SyncComputeError::BadInputList.into())), + }; } // Execute all computations. @@ -103,16 +91,20 @@ impl FhevmExecutor for FhevmExecutorService { for computation in &req.computations { let outcome = Self::process_computation(computation, &mut state); // Either all succeed or we return on the first failure. - match outcome.resp.unwrap() { - Resp::Error(error) => { - return error_response( - SyncComputeError::try_from(error).expect("correct error value"), - ); + match outcome { + Ok(cts) => result_cts.extend(cts), + Err(e) => { + return SyncComputeResponse { + resp: Some(Resp::Error(e.into())), + }; } - Resp::ResultCiphertexts(cts) => result_cts.extend(cts.ciphertexts), } } - success_response(result_cts) + SyncComputeResponse { + resp: Some(Resp::ResultCiphertexts(ResultCiphertexts { + ciphertexts: result_cts, + })), + } }) .await; match resp { @@ -135,12 +127,21 @@ impl FhevmExecutorService { fn process_computation( comp: &SyncComputation, state: &mut ComputationState, - ) -> SyncComputeResponse { + ) -> Result, SyncComputeError> { + // For now, assume only one result handle. + let result_handle = comp + .result_handles + .first() + .filter(|h| h.len() == HANDLE_LEN) + .ok_or_else(|| SyncComputeError::BadResultHandles)? + .clone(); let op = FheOperation::try_from(comp.operation); match op { - Ok(FheOperation::FheGetInputCiphertext) => Self::get_input_ciphertext(comp, &state), - Ok(_) => error_response(SyncComputeError::UnsupportedOperation), - _ => error_response(SyncComputeError::InvalidOperation), + Ok(FheOperation::FheGetCiphertext) => { + Self::get_ciphertext(comp, &result_handle, &state) + } + Ok(_) => Self::compute(comp, result_handle, state), + _ => Err(SyncComputeError::InvalidOperation), } } @@ -151,9 +152,9 @@ impl FhevmExecutorService { ) -> Result<(), FhevmError> { for list in lists { let cts = try_expand_ciphertext_list(&list, &keys.server_key)?; - let list_hash: Handle = Keccak256::digest(list).into(); + let list_hash: Handle = Keccak256::digest(list).to_vec(); for (i, ct) in cts.iter().enumerate() { - let mut handle = list_hash; + let mut handle = list_hash.clone(); handle[29] = i as u8; handle[30] = ct.type_num() as u8; handle[31] = current_ciphertext_version() as u8; @@ -169,10 +170,11 @@ impl FhevmExecutorService { Ok(()) } - fn get_input_ciphertext( + fn get_ciphertext( comp: &SyncComputation, + result_handle: &Handle, state: &ComputationState, - ) -> SyncComputeResponse { + ) -> Result, SyncComputeError> { match (comp.inputs.first(), comp.inputs.len()) { ( Some(SyncInput { @@ -180,20 +182,73 @@ impl FhevmExecutorService { }), 1, ) => { - if let Ok(handle) = (handle as &[u8]).try_into() as Result { - if let Some(in_mem_ciphertext) = state.ciphertexts.get(&handle) { - success_response(vec![Ciphertext { - handle: handle.to_vec(), + if let Some(in_mem_ciphertext) = state.ciphertexts.get(handle) { + if *handle != *result_handle { + Err(SyncComputeError::BadInputs) + } else { + Ok(vec![Ciphertext { + handle: result_handle.to_vec(), ciphertext: in_mem_ciphertext.compressed.clone(), }]) - } else { - error_response(SyncComputeError::UnknownHandle) } } else { - error_response(SyncComputeError::BadInputs) + Err(SyncComputeError::UnknownHandle) } } - _ => error_response(SyncComputeError::BadInputs), + _ => Err(SyncComputeError::BadInputs), + } + } + + fn compute( + comp: &SyncComputation, + result_handle: Handle, + state: &mut ComputationState, + ) -> Result, SyncComputeError> { + // Collect computation inputs. + let inputs: Result, Box> = comp + .inputs + .iter() + .map(|sync_input| match &sync_input.input { + Some(input) => match input { + Input::Ciphertext(c) if c.handle.len() == HANDLE_LEN => { + let ct_type = c.handle[30] as i16; + Ok(SupportedFheCiphertexts::decompress(ct_type, &c.ciphertext)?) + } + Input::InputHandle(h) => { + let ct = state.ciphertexts.get(h).ok_or(FhevmError::BadInputs)?; + Ok(ct.expanded.clone()) + } + Input::Scalar(s) if s.len() == SCALAR_LEN => { + let mut scalar = U256::default(); + scalar.copy_from_be_byte_slice(&s); + Ok(SupportedFheCiphertexts::Scalar(scalar)) + } + _ => Err(FhevmError::BadInputs.into()), + }, + None => Err(FhevmError::BadInputs.into()), + }) + .collect(); + + // Do the computation on the inputs. + match inputs { + Ok(inputs) => match perform_fhe_operation(comp.operation as i16, &inputs) { + Ok(result) => { + let compressed = result.clone().compress(); + state.ciphertexts.insert( + result_handle.clone(), + InMemoryCiphertext { + expanded: result, + compressed: compressed.clone(), + }, + ); + Ok(vec![Ciphertext { + handle: result_handle, + ciphertext: compressed, + }]) + } + Err(_) => Err(SyncComputeError::ComputationFailed), + }, + Err(_) => Err(SyncComputeError::BadInputs), } } } diff --git a/fhevm-engine/executor/tests/sync_compute.rs b/fhevm-engine/executor/tests/sync_compute.rs index 61a39e06..efa9ec8b 100644 --- a/fhevm-engine/executor/tests/sync_compute.rs +++ b/fhevm-engine/executor/tests/sync_compute.rs @@ -1,28 +1,31 @@ +use anyhow::{anyhow, Result}; use executor::server::common::FheOperation; use executor::server::executor::sync_compute_response::Resp; +use executor::server::executor::Ciphertext; use executor::server::executor::{ fhevm_executor_client::FhevmExecutorClient, SyncComputation, SyncComputeRequest, }; use executor::server::executor::{sync_input::Input, SyncInput}; +use fhevm_engine_common::types::{SupportedFheCiphertexts, HANDLE_LEN}; use tfhe::CompactCiphertextListBuilder; use utils::get_test; mod utils; #[tokio::test] -async fn get_input_ciphertexts() -> Result<(), Box> { +async fn get_input_ciphertext() -> Result<()> { let test = get_test().await; let mut client = FhevmExecutorClient::connect(test.server_addr.clone()).await?; let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key); - let list = bincode::serialize(&builder.push(10_u8).build()).unwrap(); + let list = bincode::serialize(&builder.push(10_u8).build())?; // TODO: tests for all types and avoiding passing in 2 as an identifier for FheUint8. let input_handle = test.input_handle(&list, 0, 2); let sync_input = SyncInput { - input: Some(Input::InputHandle(input_handle.to_vec())), + input: Some(Input::InputHandle(input_handle.clone())), }; let computation = SyncComputation { - operation: FheOperation::FheGetInputCiphertext.into(), - result_handles: vec![vec![0xaa]], + operation: FheOperation::FheGetCiphertext.into(), + result_handles: vec![input_handle.clone()], inputs: vec![sync_input], }; let req = SyncComputeRequest { @@ -31,18 +34,76 @@ async fn get_input_ciphertexts() -> Result<(), Box> { }; let response = client.sync_compute(req).await?; let sync_compute_response = response.get_ref(); - match &sync_compute_response.resp { - Some(Resp::ResultCiphertexts(cts)) => { - match (cts.ciphertexts.first(), cts.ciphertexts.len()) { - (Some(ct), 1) => { - if ct.handle != input_handle || ct.ciphertext.is_empty() { - assert!(false); - } + let resp = as Clone>::clone(&sync_compute_response.resp) + .ok_or_else(|| anyhow!("resp is None"))?; + match resp { + Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) { + (Some(ct), 1) => { + if ct.handle != input_handle || ct.ciphertext.is_empty() { + return Err(anyhow!("response handle or ciphertext are unexpected")); } - _ => assert!(false), + Ok(()) } - } - _ => assert!(false), + _ => Err(anyhow!("unexpected amount of result ciphertexts returned")), + }, + Resp::Error(e) => Err(anyhow!(format!("error response: {}", e))), + } +} + +#[tokio::test] +async fn fhe_compute_two_ciphertexts() -> Result<()> { + let test = get_test().await; + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()).await?; + let mut builder = CompactCiphertextListBuilder::new(&test.keys.compact_public_key); + let list = builder.push(10_u16).push(11_u16).build(); + let expander = list.expand_with_key(&test.keys.server_key)?; + let ct1 = SupportedFheCiphertexts::FheUint16( + expander + .get(0) + .ok_or(anyhow!("missing ciphertext at index 0"))??, + ); + let ct1 = test.compress(ct1); + let ct2 = SupportedFheCiphertexts::FheUint16( + expander + .get(1) + .ok_or(anyhow!("missing ciphertext at index 1"))??, + ); + let ct2 = test.compress(ct2); + let sync_input1 = SyncInput { + input: Some(Input::Ciphertext(Ciphertext { + handle: test.ciphertext_handle(&ct1, 3).to_vec(), + ciphertext: ct1, + })), + }; + let sync_input2 = SyncInput { + input: Some(Input::Ciphertext(Ciphertext { + handle: test.ciphertext_handle(&ct2, 3).to_vec(), + ciphertext: ct2, + })), + }; + let computation = SyncComputation { + operation: FheOperation::FheAdd.into(), + result_handles: vec![vec![0xaa; HANDLE_LEN]], + inputs: vec![sync_input1, sync_input2], + }; + let req = SyncComputeRequest { + computations: vec![computation], + input_lists: vec![], + }; + let response = client.sync_compute(req).await?; + let sync_compute_response = response.get_ref(); + let resp = as Clone>::clone(&sync_compute_response.resp) + .ok_or_else(|| anyhow!("resp is None"))?; + match resp { + Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) { + (Some(ct), 1) => { + if ct.handle != vec![0xaa; HANDLE_LEN] || ct.ciphertext.is_empty() { + return Err(anyhow!("response handle or ciphertext are unexpected")); + } + Ok(()) + } + _ => Err(anyhow!("unexpected amount of result ciphertexts returned")), + }, + Resp::Error(e) => Err(anyhow!(format!("error response: {}", e))), } - Ok(()) } diff --git a/fhevm-engine/executor/tests/utils.rs b/fhevm-engine/executor/tests/utils.rs index a5ea2f94..1f2a88f7 100644 --- a/fhevm-engine/executor/tests/utils.rs +++ b/fhevm-engine/executor/tests/utils.rs @@ -5,9 +5,10 @@ use executor::{cli::Args, server}; use fhevm_engine_common::{ keys::{FhevmKeys, SerializedFhevmKeys}, tfhe_ops::current_ciphertext_version, - types::Handle, + types::{Handle, SupportedFheCiphertexts}, }; use sha3::{Digest, Keccak256}; +use tfhe::set_server_key; use tokio::{sync::OnceCell, time::sleep}; pub struct TestInstance { @@ -34,12 +35,24 @@ impl TestInstance { } pub fn input_handle(&self, list: &[u8], index: u8, ct_type: u8) -> Handle { - let mut handle: Handle = Keccak256::digest(list).into(); + let mut handle: Handle = Keccak256::digest(list).to_vec(); handle[29] = index; handle[30] = ct_type; handle[31] = current_ciphertext_version() as u8; handle } + + pub fn ciphertext_handle(&self, ciphertext: &[u8], ct_type: u8) -> Handle { + let mut handle: Handle = Keccak256::digest(&ciphertext).to_vec(); + handle[30] = ct_type; + handle[31] = current_ciphertext_version() as u8; + handle + } + + pub fn compress(&self, ct: SupportedFheCiphertexts) -> Vec { + set_server_key(self.keys.server_key.clone()); + ct.compress() + } } static TEST: OnceCell> = OnceCell::const_new(); diff --git a/fhevm-engine/fhevm-engine-common/Cargo.toml b/fhevm-engine/fhevm-engine-common/Cargo.toml index f072e6e2..b3415f26 100644 --- a/fhevm-engine/fhevm-engine-common/Cargo.toml +++ b/fhevm-engine/fhevm-engine-common/Cargo.toml @@ -5,6 +5,12 @@ edition = "2021" [dependencies] tfhe.workspace = true +sha3.workspace = true +anyhow.workspace = true strum = { version = "0.26", features = ["derive"] } bincode = "1.3.3" -hex = "0.4" \ No newline at end of file +hex = "0.4" + +[[bin]] +name = "generate-keys" +path = "src/bin/generate_keys.rs" diff --git a/fhevm-engine/fhevm-engine-common/src/types.rs b/fhevm-engine/fhevm-engine-common/src/types.rs index ab120285..60feab34 100644 --- a/fhevm-engine/fhevm-engine-common/src/types.rs +++ b/fhevm-engine/fhevm-engine-common/src/types.rs @@ -1,6 +1,8 @@ +use std::error::Error; + use tfhe::integer::U256; use tfhe::prelude::FheDecrypt; -use tfhe::CompressedCiphertextListBuilder; +use tfhe::{CompressedCiphertextList, CompressedCiphertextListBuilder}; #[derive(Debug)] pub enum FhevmError { @@ -72,6 +74,8 @@ pub enum FhevmError { expected_scalar_operand_bytes: usize, got_bytes: usize, }, + BadInputs, + MissingTfheRsData, } impl std::error::Error for FhevmError {} @@ -182,6 +186,12 @@ impl std::fmt::Display for FhevmError { } => { write!(f, "only one operand can be scalar, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, second operand count: {scalar_operand_count}, max scalar operands: {max_scalar_operands}") } + Self::BadInputs => { + write!(f, "Bad inputs") + } + Self::MissingTfheRsData => { + write!(f, "Missing TFHE-rs data") + } } } } @@ -299,6 +309,33 @@ impl SupportedFheCiphertexts { let list = builder.build().expect("ciphertext compression"); bincode::serialize(&list).expect("compressed list serialization") } + + pub fn decompress(ct_type: i16, list: &[u8]) -> Result> { + let list: CompressedCiphertextList = bincode::deserialize(list)?; + match ct_type { + 1 => Ok(SupportedFheCiphertexts::FheBool( + list.get(0)? + .ok_or(Box::new(FhevmError::MissingTfheRsData))?, + )), + 2 => Ok(SupportedFheCiphertexts::FheUint8( + list.get(0)? + .ok_or(Box::new(FhevmError::MissingTfheRsData))?, + )), + 3 => Ok(SupportedFheCiphertexts::FheUint16( + list.get(0)? + .ok_or(Box::new(FhevmError::MissingTfheRsData))?, + )), + 4 => Ok(SupportedFheCiphertexts::FheUint32( + list.get(0)? + .ok_or(Box::new(FhevmError::MissingTfheRsData))?, + )), + 5 => Ok(SupportedFheCiphertexts::FheUint64( + list.get(0)? + .ok_or(Box::new(FhevmError::MissingTfheRsData))?, + )), + _ => Err(Box::new(FhevmError::UnknownFheType(ct_type as i32))), + } + } } impl SupportedFheOperations { @@ -419,4 +456,7 @@ impl From for i16 { } } -pub type Handle = [u8; 32]; +pub type Handle = Vec; +pub const HANDLE_LEN: usize = 32; +pub const SCALAR_LEN: usize = 32; + diff --git a/proto/common.proto b/proto/common.proto index 30ac4ab8..29779489 100644 --- a/proto/common.proto +++ b/proto/common.proto @@ -31,5 +31,5 @@ enum FheOperation { FHE_NOT = 21; FHE_CAST = 30; FHE_IF_THEN_ELSE = 31; - FHE_GET_INPUT_CIPHERTEXT = 32; + FHE_GET_CIPHERTEXT = 32; } diff --git a/proto/executor.proto b/proto/executor.proto index 3ef76a37..9cbe300b 100644 --- a/proto/executor.proto +++ b/proto/executor.proto @@ -50,6 +50,8 @@ enum SyncComputeError { UNSUPPORTED_OPERATION = 2; BAD_INPUTS = 3; UNKNOWN_HANDLE = 4; + COMPUTATION_FAILED = 5; + BAD_RESULT_HANDLES = 6; } // Represents a ciphertext that is an expanded input or a result of FHE computation.