Skip to content

Commit

Permalink
feat: add ciphertext compression in the coprocessor
Browse files Browse the repository at this point in the history
Move compression and hash computation to a spawn_blocking call during
input upload.

Fix a bug where decompress was treating type 11 as bytes128 instead of
bytes256.
  • Loading branch information
dartdart26 committed Sep 25, 2024
1 parent 492f44e commit 87160fd
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 124 deletions.
76 changes: 44 additions & 32 deletions fhevm-engine/coprocessor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::db_queries::{
};
use crate::server::coprocessor::GenericResponse;
use crate::types::{CoprocessorError, TfheTenantKeys};
use crate::utils::sort_computations_by_dependencies;
use crate::utils::{set_server_key_if_not_set, sort_computations_by_dependencies};
use alloy::signers::local::PrivateKeySigner;
use alloy::signers::SignerSync;
use alloy::sol_types::SolStruct;
Expand All @@ -23,6 +23,7 @@ use fhevm_engine_common::tfhe_ops::{
use fhevm_engine_common::types::{FhevmError, SupportedFheCiphertexts, SupportedFheOperations};
use sha3::{Digest, Keccak256};
use sqlx::{query, Acquire};
use tokio::task::spawn_blocking;
use tonic::transport::Server;

pub mod common {
Expand Down Expand Up @@ -177,14 +178,15 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
))
})?;
let acl_contract_address =
alloy::primitives::Address::from_str(&fetch_key_response.acl_contract_address).map_err(|e| {
tonic::Status::from_error(Box::new(
CoprocessorError::CannotParseTenantEthereumAddress {
bad_address: fetch_key_response.acl_contract_address.clone(),
parsing_error: e.to_string(),
},
))
})?;
alloy::primitives::Address::from_str(&fetch_key_response.acl_contract_address)
.map_err(|e| {
tonic::Status::from_error(Box::new(
CoprocessorError::CannotParseTenantEthereumAddress {
bad_address: fetch_key_response.acl_contract_address.clone(),
parsing_error: e.to_string(),
},
))
})?;

let eip_712_domain = alloy::sol_types::eip712_domain! {
name: "InputVerifier",
Expand Down Expand Up @@ -224,12 +226,12 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
let server_key = server_key.clone();
tfhe_work_set.spawn_blocking(
move || -> Result<_, (Box<(dyn std::error::Error + Send + Sync)>, usize)> {
let expanded =
try_expand_ciphertext_list(&cloned_input.input_payload, &server_key)
.map_err(|e| {
let err: Box<(dyn std::error::Error + Send + Sync)> = Box::new(e);
(err, idx)
})?;
set_server_key_if_not_set(tenant_id, &server_key);
let expanded = try_expand_ciphertext_list(&cloned_input.input_payload)
.map_err(|e| {
let err: Box<(dyn std::error::Error + Send + Sync)> = Box::new(e);
(err, idx)
})?;

Ok((expanded, idx))
},
Expand Down Expand Up @@ -280,7 +282,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
assert_eq!(blob_hash.len(), 32, "should be 32 bytes");

let corresponding_unpacked = results
.get(&idx)
.remove(&idx)
.expect("we should have all results computed now");

// save blob for audits and historical reference
Expand Down Expand Up @@ -320,21 +322,31 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
signer_address: self.signer.address().to_string(),
};

for (ct_idx, the_ct) in corresponding_unpacked.iter().enumerate() {
let (serialized_type, serialized_ct) = the_ct.serialize();
let ciphertext_version = current_ciphertext_version();
let mut handle_hash = Keccak256::new();
handle_hash.update(&blob_hash);
handle_hash.update(&[ct_idx as u8]);
handle_hash.update(acl_contract_address.as_slice());
handle_hash.update(&chain_id_be);
let mut handle = handle_hash.finalize().to_vec();
assert_eq!(handle.len(), 32);
// idx cast to u8 must succeed because we don't allow
// more handles than u8 size
handle[29] = ct_idx as u8;
handle[30] = serialized_type as u8;
handle[31] = ciphertext_version as u8;
let ciphertext_version = current_ciphertext_version();
for (ct_idx, the_ct) in corresponding_unpacked.into_iter().enumerate() {
// TODO: simplify compress and hash computation async handling
let blob_hash_clone = blob_hash.clone();
let server_key_clone = server_key.clone();
let (handle, serialized_ct, serialized_type) = spawn_blocking(move || {
set_server_key_if_not_set(tenant_id, &server_key_clone);
let (serialized_type, serialized_ct) = the_ct.compress();
let mut handle_hash = Keccak256::new();
handle_hash.update(&blob_hash_clone);
handle_hash.update(&[ct_idx as u8]);
handle_hash.update(acl_contract_address.as_slice());
handle_hash.update(&chain_id_be);
let mut handle = handle_hash.finalize().to_vec();
assert_eq!(handle.len(), 32);
// idx cast to u8 must succeed because we don't allow
// more handles than u8 size
handle[29] = ct_idx as u8;
handle[30] = serialized_type as u8;
handle[31] = ciphertext_version as u8;

(handle, serialized_ct, serialized_type)
})
.await
.map_err(|e| tonic::Status::from_error(Box::new(e)))?;

let _ = sqlx::query!(
"
Expand Down Expand Up @@ -572,7 +584,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
let mut res: Vec<(Vec<u8>, i16, Vec<u8>)> = Vec::with_capacity(cloned.len());
for v in cloned {
let ct = trivial_encrypt_be_bytes(v.output_type as i16, &v.be_value);
let (ct_type, ct_bytes) = ct.serialize();
let (ct_type, ct_bytes) = ct.compress();
res.push((v.handle, ct_type, ct_bytes));
}

Expand Down
22 changes: 13 additions & 9 deletions fhevm-engine/coprocessor/src/tests/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::cli::Args;
use fhevm_engine_common::tfhe_ops::{current_ciphertext_version, deserialize_fhe_ciphertext};
use fhevm_engine_common::tfhe_ops::current_ciphertext_version;
use fhevm_engine_common::types::SupportedFheCiphertexts;
use rand::Rng;
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU16, Ordering};
Expand Down Expand Up @@ -240,9 +241,9 @@ pub async fn decrypt_ciphertexts(
tenant_id: i32,
input: Vec<Vec<u8>>,
) -> Result<Vec<DecryptionResult>, Box<dyn std::error::Error>> {
let mut priv_key = sqlx::query!(
let mut keys = sqlx::query!(
"
SELECT cks_key
SELECT cks_key, sks_key
FROM tenants
WHERE tenant_id = $1
",
Expand All @@ -251,16 +252,16 @@ pub async fn decrypt_ciphertexts(
.fetch_all(pool)
.await?;

if priv_key.is_empty() || priv_key[0].cks_key.is_none() {
panic!("tenant private key not found");
if keys.is_empty() || keys[0].cks_key.is_none() {
panic!("tenant keys not found");
}

let mut ct_indexes: BTreeMap<&[u8], usize> = BTreeMap::new();
for (idx, h) in input.iter().enumerate() {
ct_indexes.insert(h.as_slice(), idx);
}

assert_eq!(priv_key.len(), 1);
assert_eq!(keys.len(), 1);

let cts = sqlx::query!(
"
Expand All @@ -281,15 +282,18 @@ pub async fn decrypt_ciphertexts(
panic!("ciphertext not found");
}

let priv_key = priv_key.pop().unwrap().cks_key.unwrap();
let keys = keys.pop().unwrap();

let mut values = tokio::task::spawn_blocking(move || {
let client_key: tfhe::ClientKey = bincode::deserialize(&priv_key).unwrap();
let client_key: tfhe::ClientKey =
bincode::deserialize(&keys.cks_key.clone().unwrap()).unwrap();
let sks: tfhe::ServerKey = bincode::deserialize(&keys.sks_key).unwrap();
tfhe::set_server_key(sks);

let mut decrypted: Vec<(Vec<u8>, DecryptionResult)> = Vec::with_capacity(cts.len());
for ct in cts {
let deserialized =
deserialize_fhe_ciphertext(ct.ciphertext_type, &ct.ciphertext).unwrap();
SupportedFheCiphertexts::decompress(ct.ciphertext_type, &ct.ciphertext).unwrap();
decrypted.push((
ct.handle,
DecryptionResult {
Expand Down
41 changes: 17 additions & 24 deletions fhevm-engine/coprocessor/src/tfhe_worker.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::utils::set_server_key_if_not_set;
use crate::{db_queries::populate_cache_with_tenant_keys, types::TfheTenantKeys};
use fhevm_engine_common::types::SupportedFheCiphertexts;
use fhevm_engine_common::{
tfhe_ops::{current_ciphertext_version, deserialize_fhe_ciphertext, perform_fhe_operation},
tfhe_ops::{current_ciphertext_version, perform_fhe_operation},
types::SupportedFheOperations,
};
use sqlx::{postgres::PgListener, query, Acquire};
use std::{
cell::Cell,
collections::{BTreeSet, HashMap},
num::NonZeroUsize,
};
Expand Down Expand Up @@ -156,7 +156,8 @@ async fn tfhe_worker_cycle(
let mut work_ciphertexts: Vec<(i16, Vec<u8>)> =
Vec::with_capacity(w.dependencies.len());
for (idx, dh) in w.dependencies.iter().enumerate() {
let is_operand_scalar = w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar();
let is_operand_scalar =
w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar();
if is_operand_scalar {
work_ciphertexts.push((-1, dh.clone()));
} else {
Expand All @@ -171,26 +172,20 @@ async fn tfhe_worker_cycle(
// copy for setting error in database
tfhe_work_set.spawn_blocking(
move || -> Result<_, (Box<(dyn std::error::Error + Send + Sync)>, i32, Vec<u8>)> {
thread_local! {
static TFHE_TENANT_ID: Cell<i32> = Cell::new(-1);
}

// set thread tenant key
// set the server key if not set
{
let mut rk = tenant_key_cache.blocking_write();
let keys = rk
.get(&w.tenant_id)
.expect("Can't get tenant key from cache");
if w.tenant_id != TFHE_TENANT_ID.get() {
tfhe::set_server_key(keys.sks.clone());
TFHE_TENANT_ID.set(w.tenant_id);
}
set_server_key_if_not_set(w.tenant_id, &keys.sks);
}

let mut deserialized_cts: Vec<SupportedFheCiphertexts> =
Vec::with_capacity(work_ciphertexts.len());
for (idx, (ct_type, ct_bytes)) in work_ciphertexts.iter().enumerate() {
let is_operand_scalar = w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar();
let is_operand_scalar =
w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar();
if is_operand_scalar {
let mut the_int = tfhe::integer::U256::default();
assert!(
Expand All @@ -208,24 +203,22 @@ async fn tfhe_worker_cycle(
deserialized_cts.push(SupportedFheCiphertexts::Scalar(the_int));
} else {
deserialized_cts.push(
deserialize_fhe_ciphertext(*ct_type, ct_bytes.as_slice()).map_err(
|e| {
SupportedFheCiphertexts::decompress(*ct_type, ct_bytes.as_slice())
.map_err(|e| {
let err: Box<dyn std::error::Error + Send + Sync> =
Box::new(e);
e.into();
(err, w.tenant_id, w.output_handle.clone())
},
)?,
})?,
);
}
}

let res =
perform_fhe_operation(w.fhe_operation, &deserialized_cts)
.map_err(|e| {
let err: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
(err, w.tenant_id, w.output_handle.clone())
})?;
let (db_type, db_bytes) = res.serialize();
perform_fhe_operation(w.fhe_operation, &deserialized_cts).map_err(|e| {
let err: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
(err, w.tenant_id, w.output_handle.clone())
})?;
let (db_type, db_bytes) = res.compress();

Ok((w, db_type, db_bytes))
},
Expand Down
16 changes: 15 additions & 1 deletion fhevm-engine/coprocessor/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::collections::{BTreeSet, HashMap, HashSet};
use std::{
cell::Cell,
collections::{BTreeSet, HashMap, HashSet},
};

use fhevm_engine_common::types::{FhevmError, SupportedFheOperations};

Expand Down Expand Up @@ -321,3 +324,14 @@ fn test_multi_level_circular_dependency_detection() {
}
}
}

pub fn set_server_key_if_not_set(tenant_id: i32, sks: &tfhe::ServerKey) {
thread_local! {
static TFHE_TENANT_ID: Cell<i32> = Cell::new(-1);
}

if tenant_id != TFHE_TENANT_ID.get() {
tfhe::set_server_key(sks.clone());
TFHE_TENANT_ID.set(tenant_id);
}
}
25 changes: 10 additions & 15 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub mod executor {

pub fn start(args: &crate::cli::Args) -> Result<()> {
let keys: Arc<FhevmKeys> = Arc::new(SerializedFhevmKeys::load_from_disk().into());
let executor = FhevmExecutorService::new(keys.clone());
let executor = FhevmExecutorService::new();
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(args.tokio_threads)
.max_blocking_threads(args.fhe_compute_threads)
Expand Down Expand Up @@ -70,24 +70,20 @@ pub struct ComputationState {
pub ciphertexts: HashMap<Handle, InMemoryCiphertext>,
}

struct FhevmExecutorService {
keys: Arc<FhevmKeys>,
}
struct FhevmExecutorService {}

#[tonic::async_trait]
impl FhevmExecutor for FhevmExecutorService {
async fn sync_compute(
&self,
req: Request<SyncComputeRequest>,
) -> Result<Response<SyncComputeResponse>, Status> {
let keys = self.keys.clone();
let resp = spawn_blocking(move || {
let req = req.get_ref();
let mut state = ComputationState::default();

// Exapnd compact ciphertext lists for the whole request.
if Self::expand_compact_lists(&req.compact_ciphertext_lists, &keys, &mut state).is_err()
{
if Self::expand_compact_lists(&req.compact_ciphertext_lists, &mut state).is_err() {
return SyncComputeResponse {
resp: Some(Resp::Error(SyncComputeError::BadInputList.into())),
};
Expand Down Expand Up @@ -138,8 +134,8 @@ impl FhevmExecutor for FhevmExecutorService {
}

impl FhevmExecutorService {
fn new(keys: Arc<FhevmKeys>) -> Self {
FhevmExecutorService { keys }
fn new() -> Self {
FhevmExecutorService {}
}

#[allow(dead_code)]
Expand All @@ -166,11 +162,10 @@ impl FhevmExecutorService {

fn expand_compact_lists(
lists: &Vec<Vec<u8>>,
keys: &FhevmKeys,
state: &mut ComputationState,
) -> Result<(), FhevmError> {
for list in lists {
let cts = try_expand_ciphertext_list(&list, &keys.server_key)?;
let cts = try_expand_ciphertext_list(&list)?;
let list_hash: Handle = Keccak256::digest(list).to_vec();
for (i, ct) in cts.iter().enumerate() {
let mut handle = list_hash.clone();
Expand All @@ -181,7 +176,7 @@ impl FhevmExecutorService {
handle,
InMemoryCiphertext {
expanded: ct.clone(),
compressed: ct.clone().compress(),
compressed: ct.clone().compress().1,
},
);
}
Expand Down Expand Up @@ -268,7 +263,7 @@ impl FhevmExecutorService {
match inputs {
Ok(inputs) => match perform_fhe_operation(comp.operation as i16, &inputs) {
Ok(result) => {
let compressed = result.clone().compress();
let (_, compressed) = result.clone().compress();
state.ciphertexts.insert(
result_handle.clone(),
InMemoryCiphertext {
Expand Down Expand Up @@ -299,15 +294,15 @@ pub fn run_computation(
Ok(FheOperation::FheGetCiphertext) => {
let res = InMemoryCiphertext {
expanded: inputs[0].clone(),
compressed: inputs[0].clone().compress(),
compressed: inputs[0].clone().compress().1,
};
Ok((graph_node_index, res))
}
Ok(_) => match perform_fhe_operation(operation as i16, &inputs) {
Ok(result) => {
let res = InMemoryCiphertext {
expanded: result.clone(),
compressed: result.compress(),
compressed: result.compress().1,
};
Ok((graph_node_index, res))
}
Expand Down
Loading

0 comments on commit 87160fd

Please sign in to comment.