diff --git a/fhevm-engine/coprocessor/src/server.rs b/fhevm-engine/coprocessor/src/server.rs index bb7193a1..45b4ac66 100644 --- a/fhevm-engine/coprocessor/src/server.rs +++ b/fhevm-engine/coprocessor/src/server.rs @@ -7,7 +7,7 @@ use crate::db_queries::{ }; use crate::server::coprocessor::GenericResponse; use crate::types::{CoprocessorError, TfheTenantKeys}; -use crate::utils::sort_computations_by_dependencies; +use crate::utils::{set_server_key_if_not_set, sort_computations_by_dependencies}; use alloy::signers::local::PrivateKeySigner; use alloy::signers::SignerSync; use alloy::sol_types::SolStruct; @@ -23,6 +23,7 @@ use fhevm_engine_common::tfhe_ops::{ use fhevm_engine_common::types::{FhevmError, SupportedFheCiphertexts, SupportedFheOperations}; use sha3::{Digest, Keccak256}; use sqlx::{query, Acquire}; +use tokio::task::spawn_blocking; use tonic::transport::Server; pub mod common { @@ -177,14 +178,15 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ )) })?; let acl_contract_address = - alloy::primitives::Address::from_str(&fetch_key_response.acl_contract_address).map_err(|e| { - tonic::Status::from_error(Box::new( - CoprocessorError::CannotParseTenantEthereumAddress { - bad_address: fetch_key_response.acl_contract_address.clone(), - parsing_error: e.to_string(), - }, - )) - })?; + alloy::primitives::Address::from_str(&fetch_key_response.acl_contract_address) + .map_err(|e| { + tonic::Status::from_error(Box::new( + CoprocessorError::CannotParseTenantEthereumAddress { + bad_address: fetch_key_response.acl_contract_address.clone(), + parsing_error: e.to_string(), + }, + )) + })?; let eip_712_domain = alloy::sol_types::eip712_domain! { name: "InputVerifier", @@ -224,12 +226,12 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let server_key = server_key.clone(); tfhe_work_set.spawn_blocking( move || -> Result<_, (Box<(dyn std::error::Error + Send + Sync)>, usize)> { - let expanded = - try_expand_ciphertext_list(&cloned_input.input_payload, &server_key) - .map_err(|e| { - let err: Box<(dyn std::error::Error + Send + Sync)> = Box::new(e); - (err, idx) - })?; + set_server_key_if_not_set(tenant_id, &server_key); + let expanded = try_expand_ciphertext_list(&cloned_input.input_payload) + .map_err(|e| { + let err: Box<(dyn std::error::Error + Send + Sync)> = Box::new(e); + (err, idx) + })?; Ok((expanded, idx)) }, @@ -280,7 +282,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ assert_eq!(blob_hash.len(), 32, "should be 32 bytes"); let corresponding_unpacked = results - .get(&idx) + .remove(&idx) .expect("we should have all results computed now"); // save blob for audits and historical reference @@ -320,21 +322,31 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ signer_address: self.signer.address().to_string(), }; - for (ct_idx, the_ct) in corresponding_unpacked.iter().enumerate() { - let (serialized_type, serialized_ct) = the_ct.serialize(); - let ciphertext_version = current_ciphertext_version(); - let mut handle_hash = Keccak256::new(); - handle_hash.update(&blob_hash); - handle_hash.update(&[ct_idx as u8]); - handle_hash.update(acl_contract_address.as_slice()); - handle_hash.update(&chain_id_be); - let mut handle = handle_hash.finalize().to_vec(); - assert_eq!(handle.len(), 32); - // idx cast to u8 must succeed because we don't allow - // more handles than u8 size - handle[29] = ct_idx as u8; - handle[30] = serialized_type as u8; - handle[31] = ciphertext_version as u8; + let ciphertext_version = current_ciphertext_version(); + for (ct_idx, the_ct) in corresponding_unpacked.into_iter().enumerate() { + // TODO: simplify compress and hash computation async handling + let blob_hash_clone = blob_hash.clone(); + let server_key_clone = server_key.clone(); + let (handle, serialized_ct, serialized_type) = spawn_blocking(move || { + set_server_key_if_not_set(tenant_id, &server_key_clone); + let (serialized_type, serialized_ct) = the_ct.compress(); + let mut handle_hash = Keccak256::new(); + handle_hash.update(&blob_hash_clone); + handle_hash.update(&[ct_idx as u8]); + handle_hash.update(acl_contract_address.as_slice()); + handle_hash.update(&chain_id_be); + let mut handle = handle_hash.finalize().to_vec(); + assert_eq!(handle.len(), 32); + // idx cast to u8 must succeed because we don't allow + // more handles than u8 size + handle[29] = ct_idx as u8; + handle[30] = serialized_type as u8; + handle[31] = ciphertext_version as u8; + + (handle, serialized_ct, serialized_type) + }) + .await + .map_err(|e| tonic::Status::from_error(Box::new(e)))?; let _ = sqlx::query!( " @@ -572,7 +584,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let mut res: Vec<(Vec, i16, Vec)> = Vec::with_capacity(cloned.len()); for v in cloned { let ct = trivial_encrypt_be_bytes(v.output_type as i16, &v.be_value); - let (ct_type, ct_bytes) = ct.serialize(); + let (ct_type, ct_bytes) = ct.compress(); res.push((v.handle, ct_type, ct_bytes)); } diff --git a/fhevm-engine/coprocessor/src/tests/utils.rs b/fhevm-engine/coprocessor/src/tests/utils.rs index bb7bd393..4c9ca3b7 100644 --- a/fhevm-engine/coprocessor/src/tests/utils.rs +++ b/fhevm-engine/coprocessor/src/tests/utils.rs @@ -1,5 +1,6 @@ use crate::cli::Args; -use fhevm_engine_common::tfhe_ops::{current_ciphertext_version, deserialize_fhe_ciphertext}; +use fhevm_engine_common::tfhe_ops::current_ciphertext_version; +use fhevm_engine_common::types::SupportedFheCiphertexts; use rand::Rng; use std::collections::BTreeMap; use std::sync::atomic::{AtomicU16, Ordering}; @@ -240,9 +241,9 @@ pub async fn decrypt_ciphertexts( tenant_id: i32, input: Vec>, ) -> Result, Box> { - let mut priv_key = sqlx::query!( + let mut keys = sqlx::query!( " - SELECT cks_key + SELECT cks_key, sks_key FROM tenants WHERE tenant_id = $1 ", @@ -251,8 +252,8 @@ pub async fn decrypt_ciphertexts( .fetch_all(pool) .await?; - if priv_key.is_empty() || priv_key[0].cks_key.is_none() { - panic!("tenant private key not found"); + if keys.is_empty() || keys[0].cks_key.is_none() { + panic!("tenant keys not found"); } let mut ct_indexes: BTreeMap<&[u8], usize> = BTreeMap::new(); @@ -260,7 +261,7 @@ pub async fn decrypt_ciphertexts( ct_indexes.insert(h.as_slice(), idx); } - assert_eq!(priv_key.len(), 1); + assert_eq!(keys.len(), 1); let cts = sqlx::query!( " @@ -281,15 +282,18 @@ pub async fn decrypt_ciphertexts( panic!("ciphertext not found"); } - let priv_key = priv_key.pop().unwrap().cks_key.unwrap(); + let keys = keys.pop().unwrap(); let mut values = tokio::task::spawn_blocking(move || { - let client_key: tfhe::ClientKey = bincode::deserialize(&priv_key).unwrap(); + let client_key: tfhe::ClientKey = + bincode::deserialize(&keys.cks_key.clone().unwrap()).unwrap(); + let sks: tfhe::ServerKey = bincode::deserialize(&keys.sks_key).unwrap(); + tfhe::set_server_key(sks); let mut decrypted: Vec<(Vec, DecryptionResult)> = Vec::with_capacity(cts.len()); for ct in cts { let deserialized = - deserialize_fhe_ciphertext(ct.ciphertext_type, &ct.ciphertext).unwrap(); + SupportedFheCiphertexts::decompress(ct.ciphertext_type, &ct.ciphertext).unwrap(); decrypted.push(( ct.handle, DecryptionResult { diff --git a/fhevm-engine/coprocessor/src/tfhe_worker.rs b/fhevm-engine/coprocessor/src/tfhe_worker.rs index 7ce8994e..b5f8a5b4 100644 --- a/fhevm-engine/coprocessor/src/tfhe_worker.rs +++ b/fhevm-engine/coprocessor/src/tfhe_worker.rs @@ -1,12 +1,12 @@ +use crate::utils::set_server_key_if_not_set; use crate::{db_queries::populate_cache_with_tenant_keys, types::TfheTenantKeys}; use fhevm_engine_common::types::SupportedFheCiphertexts; use fhevm_engine_common::{ - tfhe_ops::{current_ciphertext_version, deserialize_fhe_ciphertext, perform_fhe_operation}, + tfhe_ops::{current_ciphertext_version, perform_fhe_operation}, types::SupportedFheOperations, }; use sqlx::{postgres::PgListener, query, Acquire}; use std::{ - cell::Cell, collections::{BTreeSet, HashMap}, num::NonZeroUsize, }; @@ -156,7 +156,8 @@ async fn tfhe_worker_cycle( let mut work_ciphertexts: Vec<(i16, Vec)> = Vec::with_capacity(w.dependencies.len()); for (idx, dh) in w.dependencies.iter().enumerate() { - let is_operand_scalar = w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar(); + let is_operand_scalar = + w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar(); if is_operand_scalar { work_ciphertexts.push((-1, dh.clone())); } else { @@ -171,26 +172,20 @@ async fn tfhe_worker_cycle( // copy for setting error in database tfhe_work_set.spawn_blocking( move || -> Result<_, (Box<(dyn std::error::Error + Send + Sync)>, i32, Vec)> { - thread_local! { - static TFHE_TENANT_ID: Cell = Cell::new(-1); - } - - // set thread tenant key + // set the server key if not set { let mut rk = tenant_key_cache.blocking_write(); let keys = rk .get(&w.tenant_id) .expect("Can't get tenant key from cache"); - if w.tenant_id != TFHE_TENANT_ID.get() { - tfhe::set_server_key(keys.sks.clone()); - TFHE_TENANT_ID.set(w.tenant_id); - } + set_server_key_if_not_set(w.tenant_id, &keys.sks); } let mut deserialized_cts: Vec = Vec::with_capacity(work_ciphertexts.len()); for (idx, (ct_type, ct_bytes)) in work_ciphertexts.iter().enumerate() { - let is_operand_scalar = w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar(); + let is_operand_scalar = + w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar(); if is_operand_scalar { let mut the_int = tfhe::integer::U256::default(); assert!( @@ -208,24 +203,22 @@ async fn tfhe_worker_cycle( deserialized_cts.push(SupportedFheCiphertexts::Scalar(the_int)); } else { deserialized_cts.push( - deserialize_fhe_ciphertext(*ct_type, ct_bytes.as_slice()).map_err( - |e| { + SupportedFheCiphertexts::decompress(*ct_type, ct_bytes.as_slice()) + .map_err(|e| { let err: Box = - Box::new(e); + e.into(); (err, w.tenant_id, w.output_handle.clone()) - }, - )?, + })?, ); } } let res = - perform_fhe_operation(w.fhe_operation, &deserialized_cts) - .map_err(|e| { - let err: Box = Box::new(e); - (err, w.tenant_id, w.output_handle.clone()) - })?; - let (db_type, db_bytes) = res.serialize(); + perform_fhe_operation(w.fhe_operation, &deserialized_cts).map_err(|e| { + let err: Box = Box::new(e); + (err, w.tenant_id, w.output_handle.clone()) + })?; + let (db_type, db_bytes) = res.compress(); Ok((w, db_type, db_bytes)) }, diff --git a/fhevm-engine/coprocessor/src/utils.rs b/fhevm-engine/coprocessor/src/utils.rs index 494b96af..ab0985f6 100644 --- a/fhevm-engine/coprocessor/src/utils.rs +++ b/fhevm-engine/coprocessor/src/utils.rs @@ -1,4 +1,7 @@ -use std::collections::{BTreeSet, HashMap, HashSet}; +use std::{ + cell::Cell, + collections::{BTreeSet, HashMap, HashSet}, +}; use fhevm_engine_common::types::{FhevmError, SupportedFheOperations}; @@ -321,3 +324,14 @@ fn test_multi_level_circular_dependency_detection() { } } } + +pub fn set_server_key_if_not_set(tenant_id: i32, sks: &tfhe::ServerKey) { + thread_local! { + static TFHE_TENANT_ID: Cell = Cell::new(-1); + } + + if tenant_id != TFHE_TENANT_ID.get() { + tfhe::set_server_key(sks.clone()); + TFHE_TENANT_ID.set(tenant_id); + } +} diff --git a/fhevm-engine/executor/src/server.rs b/fhevm-engine/executor/src/server.rs index 1cc3d06e..f0472801 100644 --- a/fhevm-engine/executor/src/server.rs +++ b/fhevm-engine/executor/src/server.rs @@ -32,7 +32,7 @@ pub mod executor { pub fn start(args: &crate::cli::Args) -> Result<()> { let keys: Arc = Arc::new(SerializedFhevmKeys::load_from_disk().into()); - let executor = FhevmExecutorService::new(keys.clone()); + let executor = FhevmExecutorService::new(); let runtime = tokio::runtime::Builder::new_multi_thread() .worker_threads(args.tokio_threads) .max_blocking_threads(args.fhe_compute_threads) @@ -70,9 +70,7 @@ pub struct ComputationState { pub ciphertexts: HashMap, } -struct FhevmExecutorService { - keys: Arc, -} +struct FhevmExecutorService {} #[tonic::async_trait] impl FhevmExecutor for FhevmExecutorService { @@ -80,14 +78,12 @@ impl FhevmExecutor for FhevmExecutorService { &self, req: Request, ) -> Result, Status> { - let keys = self.keys.clone(); let resp = spawn_blocking(move || { let req = req.get_ref(); let mut state = ComputationState::default(); // Exapnd compact ciphertext lists for the whole request. - if Self::expand_compact_lists(&req.compact_ciphertext_lists, &keys, &mut state).is_err() - { + if Self::expand_compact_lists(&req.compact_ciphertext_lists, &mut state).is_err() { return SyncComputeResponse { resp: Some(Resp::Error(SyncComputeError::BadInputList.into())), }; @@ -138,8 +134,8 @@ impl FhevmExecutor for FhevmExecutorService { } impl FhevmExecutorService { - fn new(keys: Arc) -> Self { - FhevmExecutorService { keys } + fn new() -> Self { + FhevmExecutorService {} } #[allow(dead_code)] @@ -166,11 +162,10 @@ impl FhevmExecutorService { fn expand_compact_lists( lists: &Vec>, - keys: &FhevmKeys, state: &mut ComputationState, ) -> Result<(), FhevmError> { for list in lists { - let cts = try_expand_ciphertext_list(&list, &keys.server_key)?; + let cts = try_expand_ciphertext_list(&list)?; let list_hash: Handle = Keccak256::digest(list).to_vec(); for (i, ct) in cts.iter().enumerate() { let mut handle = list_hash.clone(); @@ -181,7 +176,7 @@ impl FhevmExecutorService { handle, InMemoryCiphertext { expanded: ct.clone(), - compressed: ct.clone().compress(), + compressed: ct.clone().compress().1, }, ); } @@ -268,7 +263,7 @@ impl FhevmExecutorService { match inputs { Ok(inputs) => match perform_fhe_operation(comp.operation as i16, &inputs) { Ok(result) => { - let compressed = result.clone().compress(); + let (_, compressed) = result.clone().compress(); state.ciphertexts.insert( result_handle.clone(), InMemoryCiphertext { @@ -299,7 +294,7 @@ pub fn run_computation( Ok(FheOperation::FheGetCiphertext) => { let res = InMemoryCiphertext { expanded: inputs[0].clone(), - compressed: inputs[0].clone().compress(), + compressed: inputs[0].clone().compress().1, }; Ok((graph_node_index, res)) } @@ -307,7 +302,7 @@ pub fn run_computation( Ok(result) => { let res = InMemoryCiphertext { expanded: result.clone(), - compressed: result.compress(), + compressed: result.compress().1, }; Ok((graph_node_index, res)) } diff --git a/fhevm-engine/executor/tests/utils.rs b/fhevm-engine/executor/tests/utils.rs index 1f2a88f7..44483940 100644 --- a/fhevm-engine/executor/tests/utils.rs +++ b/fhevm-engine/executor/tests/utils.rs @@ -51,7 +51,7 @@ impl TestInstance { pub fn compress(&self, ct: SupportedFheCiphertexts) -> Vec { set_server_key(self.keys.server_key.clone()); - ct.compress() + ct.compress().1 } } diff --git a/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs b/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs index 8c503ef1..0245b13f 100644 --- a/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs +++ b/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs @@ -78,16 +78,18 @@ pub fn deserialize_fhe_ciphertext( pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> SupportedFheCiphertexts { let last_byte = if input_bytes.len() > 0 { input_bytes[input_bytes.len() - 1] - } else { 0 }; + } else { + 0 + }; match output_type { 0 => SupportedFheCiphertexts::FheBool( - FheBool::try_encrypt_trivial(last_byte > 0).unwrap(), + FheBool::try_encrypt_trivial(last_byte > 0).expect("trival encrypt bool"), ), 1 => SupportedFheCiphertexts::FheUint4( - FheUint4::try_encrypt_trivial(last_byte).unwrap() + FheUint4::try_encrypt_trivial(last_byte).expect("trivial encrypt 4"), ), 2 => SupportedFheCiphertexts::FheUint8( - FheUint8::try_encrypt_trivial(last_byte).unwrap(), + FheUint8::try_encrypt_trivial(last_byte).expect("trivial encrypt 8"), ), 3 => { let mut padded: [u8; 2] = [0; 2]; @@ -99,10 +101,13 @@ pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> Support 0 }; let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[input_bytes.len()-len..]); + padded[copy_from..padded_len] + .copy_from_slice(&input_bytes[input_bytes.len() - len..]); } let res = u16::from_be_bytes(padded); - SupportedFheCiphertexts::FheUint16(FheUint16::try_encrypt_trivial(res).unwrap()) + SupportedFheCiphertexts::FheUint16( + FheUint16::try_encrypt_trivial(res).expect("trivial encrypt 16"), + ) } 4 => { let mut padded: [u8; 4] = [0; 4]; @@ -114,10 +119,13 @@ pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> Support 0 }; let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[input_bytes.len()-len..]); + padded[copy_from..padded_len] + .copy_from_slice(&input_bytes[input_bytes.len() - len..]); } let res: u32 = u32::from_be_bytes(padded); - SupportedFheCiphertexts::FheUint32(FheUint32::try_encrypt_trivial(res).unwrap()) + SupportedFheCiphertexts::FheUint32( + FheUint32::try_encrypt_trivial(res).expect("trivial encrypt 32"), + ) } 5 => { let mut padded: [u8; 8] = [0; 8]; @@ -129,10 +137,13 @@ pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> Support 0 }; let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[input_bytes.len()-len..]); + padded[copy_from..padded_len] + .copy_from_slice(&input_bytes[input_bytes.len() - len..]); } let res: u64 = u64::from_be_bytes(padded); - SupportedFheCiphertexts::FheUint64(FheUint64::try_encrypt_trivial(res).unwrap()) + SupportedFheCiphertexts::FheUint64( + FheUint64::try_encrypt_trivial(res).expect("trivial encrypt 64"), + ) } 6 => { let mut padded: [u8; 16] = [0; 16]; @@ -144,10 +155,11 @@ pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> Support 0 }; let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[input_bytes.len()-len..]); + padded[copy_from..padded_len] + .copy_from_slice(&input_bytes[input_bytes.len() - len..]); } let res: u128 = u128::from_be_bytes(padded); - let output = FheUint128::try_encrypt_trivial(res).unwrap(); + let output = FheUint128::try_encrypt_trivial(res).expect("trivial encrypt 128"); SupportedFheCiphertexts::FheUint128(output) } 7 => { @@ -161,10 +173,13 @@ pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> Support 0 }; let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[input_bytes.len()-len..]); + padded[copy_from..padded_len] + .copy_from_slice(&input_bytes[input_bytes.len() - len..]); be.copy_from_be_byte_slice(&padded); } - let output: FheUint160 = FheUint256::try_encrypt_trivial(be).unwrap().cast_into(); + let output: FheUint160 = FheUint256::try_encrypt_trivial(be) + .expect("trivial encrypt 160") + .cast_into(); SupportedFheCiphertexts::FheUint160(output) } 8 => { @@ -178,10 +193,11 @@ pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> Support 0 }; let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[input_bytes.len()-len..]); + padded[copy_from..padded_len] + .copy_from_slice(&input_bytes[input_bytes.len() - len..]); be.copy_from_be_byte_slice(&padded); } - let output = FheUint256::try_encrypt_trivial(be).unwrap(); + let output = FheUint256::try_encrypt_trivial(be).expect("trivial encrypt 256"); SupportedFheCiphertexts::FheUint256(output) } 9 => { @@ -195,10 +211,11 @@ pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> Support 0 }; let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[input_bytes.len()-len..]); + padded[copy_from..padded_len] + .copy_from_slice(&input_bytes[input_bytes.len() - len..]); be.copy_from_be_byte_slice(&padded); } - let output = FheUint512::try_encrypt_trivial(be).unwrap(); + let output = FheUint512::try_encrypt_trivial(be).expect("trivial encrypt 512"); SupportedFheCiphertexts::FheBytes64(output) } 10 => { @@ -212,10 +229,11 @@ pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> Support 0 }; let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[input_bytes.len()-len..]); + padded[copy_from..padded_len] + .copy_from_slice(&input_bytes[input_bytes.len() - len..]); be.copy_from_be_byte_slice(&padded); } - let output = FheUint1024::try_encrypt_trivial(be).unwrap(); + let output = FheUint1024::try_encrypt_trivial(be).expect("trivial encrypt 1024"); SupportedFheCiphertexts::FheBytes128(output) } 11 => { @@ -229,10 +247,11 @@ pub fn trivial_encrypt_be_bytes(output_type: i16, input_bytes: &[u8]) -> Support 0 }; let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[input_bytes.len()-len..]); + padded[copy_from..padded_len] + .copy_from_slice(&input_bytes[input_bytes.len() - len..]); be.copy_from_be_byte_slice(&padded); } - let output = FheUint2048::try_encrypt_trivial(be).unwrap(); + let output = FheUint2048::try_encrypt_trivial(be).expect("trivial encrypt 2048"); SupportedFheCiphertexts::FheBytes256(output) } other => { @@ -247,14 +266,11 @@ pub fn current_ciphertext_version() -> i16 { pub fn try_expand_ciphertext_list( input_ciphertext: &[u8], - server_key: &tfhe::ServerKey, ) -> Result, FhevmError> { let mut res = Vec::new(); let the_list: tfhe::ProvenCompactCiphertextList = safe_deserialize(input_ciphertext)?; - // TODO: we can do better and avoid cloning - tfhe::set_server_key(server_key.clone()); let expanded = the_list .expand_without_verification() .map_err(|e| FhevmError::CiphertextExpansionError(e))?; @@ -3210,7 +3226,8 @@ pub fn perform_fhe_operation( panic!("unknown cast pair") } }, - SupportedFheOperations::FheTrivialEncrypt => match (&input_operands[0], &input_operands[1]) { + SupportedFheOperations::FheTrivialEncrypt => match (&input_operands[0], &input_operands[1]) + { (SupportedFheCiphertexts::Scalar(inp), SupportedFheCiphertexts::Scalar(op)) => { let (l, _) = op.to_low_high_u128(); let l = l as i16; diff --git a/fhevm-engine/fhevm-engine-common/src/types.rs b/fhevm-engine/fhevm-engine-common/src/types.rs index d282185f..30d4fca9 100644 --- a/fhevm-engine/fhevm-engine-common/src/types.rs +++ b/fhevm-engine/fhevm-engine-common/src/types.rs @@ -440,28 +440,29 @@ impl SupportedFheCiphertexts { } } - pub fn compress(self) -> Vec { + pub fn compress(&self) -> (i16, Vec) { + let type_num = self.type_num(); let mut builder = CompressedCiphertextListBuilder::new(); match self { - SupportedFheCiphertexts::FheBool(c) => builder.push(c), - SupportedFheCiphertexts::FheUint4(c) => builder.push(c), - SupportedFheCiphertexts::FheUint8(c) => builder.push(c), - SupportedFheCiphertexts::FheUint16(c) => builder.push(c), - SupportedFheCiphertexts::FheUint32(c) => builder.push(c), - SupportedFheCiphertexts::FheUint64(c) => builder.push(c), - SupportedFheCiphertexts::FheUint128(c) => builder.push(c), - SupportedFheCiphertexts::FheUint160(c) => builder.push(c), - SupportedFheCiphertexts::FheUint256(c) => builder.push(c), - SupportedFheCiphertexts::FheBytes64(c) => builder.push(c), - SupportedFheCiphertexts::FheBytes128(c) => builder.push(c), - SupportedFheCiphertexts::FheBytes256(c) => builder.push(c), + SupportedFheCiphertexts::FheBool(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheUint4(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheUint8(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheUint16(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheUint32(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheUint64(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheUint128(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheUint160(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheUint256(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheBytes64(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheBytes128(c) => builder.push(c.clone()), + SupportedFheCiphertexts::FheBytes256(c) => builder.push(c.clone()), SupportedFheCiphertexts::Scalar(_) => { // TODO: Need to fix that, scalars are not ciphertexts. panic!("cannot compress a scalar"); } }; let list = builder.build().expect("ciphertext compression"); - safe_serialize(&list) + (type_num, safe_serialize(&list)) } pub fn decompress(ct_type: i16, list: &[u8]) -> Result { @@ -500,7 +501,7 @@ impl SupportedFheCiphertexts { 10 => Ok(SupportedFheCiphertexts::FheBytes128( list.get(0)?.ok_or(FhevmError::MissingTfheRsData)?, )), - 11 => Ok(SupportedFheCiphertexts::FheBytes128( + 11 => Ok(SupportedFheCiphertexts::FheBytes256( list.get(0)?.ok_or(FhevmError::MissingTfheRsData)?, )), _ => Err(FhevmError::UnknownFheType(ct_type as i32).into()), @@ -575,7 +576,9 @@ impl SupportedFheOperations { pub fn does_have_more_than_one_scalar(&self) -> bool { match self { - SupportedFheOperations::FheRand | SupportedFheOperations::FheRandBounded | SupportedFheOperations::FheTrivialEncrypt => true, + SupportedFheOperations::FheRand + | SupportedFheOperations::FheRandBounded + | SupportedFheOperations::FheTrivialEncrypt => true, _ => false, } }