Skip to content

Commit

Permalink
fix: add arbitrary size scalar support
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zk committed Nov 28, 2024
1 parent a630534 commit d7a5472
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 24 deletions.
45 changes: 22 additions & 23 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use fhevm_engine_common::{
common::FheOperation,
keys::{FhevmKeys, SerializedFhevmKeys},
tfhe_ops::{current_ciphertext_version, perform_fhe_operation, try_expand_ciphertext_list},
types::{get_ct_type, FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN},
types::{get_ct_type, FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN},
};
use sha3::{Digest, Keccak256};
use std::{cell::Cell, collections::HashMap};
Expand Down Expand Up @@ -264,10 +264,9 @@ impl FhevmExecutorService {
let ct = state.ciphertexts.get(h).ok_or(FhevmError::BadInputs)?;
Ok(ct.expanded.clone())
}
Input::Scalar(s) if s.len() == SCALAR_LEN => {
Input::Scalar(s) => {
Ok(SupportedFheCiphertexts::Scalar(s.clone()))
}
_ => Err(FhevmError::BadInputs.into()),
},
None => Err(FhevmError::BadInputs.into()),
})
Expand Down Expand Up @@ -305,7 +304,7 @@ pub fn build_taskgraph_from_request(
let mut produced_handles: HashMap<&Handle, usize> = HashMap::new();
// Add all computations as nodes in the graph.
for computation in &req.computations {
let inputs: Result<Vec<DFGTaskInput>> = computation
let inputs = computation
.inputs
.iter()
.map(|input| match &input.input {
Expand All @@ -317,29 +316,29 @@ pub fn build_taskgraph_from_request(
Ok(DFGTaskInput::Dependence(None))
}
}
Input::Scalar(s) if s.len() == SCALAR_LEN => Ok(DFGTaskInput::Value(
Input::Scalar(s) => Ok(DFGTaskInput::Value(
SupportedFheCiphertexts::Scalar(s.clone()),
)),
_ => Err(FhevmError::BadInputs.into()),
},
None => Err(FhevmError::BadInputs.into()),
None => Err(SyncComputeError::BadInputs),
})
.collect();
if let Ok(mut inputs) = inputs {
let res_handle = computation
.result_handles
.first()
.filter(|h| h.len() == HANDLE_LEN)
.ok_or(SyncComputeError::BadResultHandles)?;
let n = dfg
.add_node(
res_handle.clone(),
computation.operation,
std::mem::take(&mut inputs),
)
.or_else(|_| Err(SyncComputeError::ComputationFailed))?;
produced_handles.insert(res_handle, n.index());
}
.collect::<Result<Vec<DFGTaskInput>, SyncComputeError>>();

let mut inputs = inputs?;

let res_handle = computation
.result_handles
.first()
.filter(|h| h.len() == HANDLE_LEN)
.ok_or(SyncComputeError::BadResultHandles)?;
let n = dfg
.add_node(
res_handle.clone(),
computation.operation,
std::mem::take(&mut inputs),
)
.or_else(|_| Err(SyncComputeError::ComputationFailed))?;
produced_handles.insert(res_handle, n.index());
}
// Traverse computations and add dependences/edges as required
for (index, computation) in req.computations.iter().enumerate() {
Expand Down
50 changes: 50 additions & 0 deletions fhevm-engine/executor/tests/sync_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,53 @@ async fn compute_on_result_ciphertext() {
Resp::Error(e) => assert!(false, "error response: {}", e),
}
}

#[tokio::test]
async fn trivial_encryption_scalar_less_than_32_bytes() {
let test = get_test().await;
test.keys.set_server_key_for_current_thread();
let mut client = FhevmExecutorClient::connect(test.server_addr.clone())
.await
.unwrap();
// 10 big endian
let mut triv_encrypt_input = vec![0; 31];
triv_encrypt_input.push(10);
let sync_input1 = SyncInput {
input: Some(Input::Scalar(triv_encrypt_input)),
};
let sync_input2 = SyncInput {
input: Some(Input::Scalar(vec![3])),
};
let computation = SyncComputation {
operation: FheOperation::FheTrivialEncrypt.into(),
result_handles: vec![vec![0xaa; HANDLE_LEN]],
inputs: vec![sync_input1, sync_input2],
};
let req = SyncComputeRequest {
computations: vec![computation],
compact_ciphertext_lists: vec![],
compressed_ciphertexts: vec![],
};
let response = client.sync_compute(req).await.unwrap();
let sync_compute_response = response.get_ref();
let resp = sync_compute_response.resp.clone().unwrap();
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != vec![0xaa; HANDLE_LEN] {
panic!("response handle is unexpected: {:?}", ct.handle);
}
let ct = SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap();
match ct
.decrypt(&test.as_ref().keys.client_key.clone().unwrap())
.as_str()
{
"10" => (),
s => assert!(false, "unexpected result: {}", s),
}
}
_ => panic!("unexpected amount of result ciphertexts returned: {}", cts.ciphertexts.len()),
},
Resp::Error(e) => assert!(false, "error response: {}", e),
}
}
1 change: 0 additions & 1 deletion fhevm-engine/fhevm-engine-common/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,6 @@ impl From<SupportedFheOperations> for i16 {

pub type Handle = Vec<u8>;
pub const HANDLE_LEN: usize = 32;
pub const SCALAR_LEN: usize = 32;

pub fn get_ct_type(handle: &[u8]) -> Result<i16, FhevmError> {
match handle.len() {
Expand Down

0 comments on commit d7a5472

Please sign in to comment.