Skip to content

Commit

Permalink
feat: use versioned safe ser/deser
Browse files Browse the repository at this point in the history
Use the versioned safe ser/deser tfhe-rs functions. Versioning might
become part of safe ser/deser, but for now we better be explicit and not
rely on that.

`ProvenCompactCiphertextList` uses safe_(de)serialize() now as
versioning is not supported in thfe-rs yet. Will move to versioned once
available.

`CompactPkePublicParams` uses (de)serialize_with_mode() as of now as
versioning is not supported yet.

Also, remove explicit thread local and Arc handling for server key - it
is already an Arc internally, so it is cheap to clone.
  • Loading branch information
dartdart26 committed Sep 25, 2024
1 parent d799643 commit 0ebd723
Show file tree
Hide file tree
Showing 20 changed files with 191 additions and 148 deletions.
12 changes: 6 additions & 6 deletions fhevm-engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions fhevm-engine/coprocessor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ version = "0.1.0"
edition = "2021"

[target.'cfg(target_arch = "x86_64")'.dependencies]
tfhe = { version = "0.8.0-alpha.8", features = ["boolean", "shortint", "integer", "x86_64-unix", "zk-pok", "experimental-force_fft_algo_dif4"] }
tfhe = { version = "0.8.0-alpha.9", features = ["boolean", "shortint", "integer", "x86_64-unix", "zk-pok", "experimental-force_fft_algo_dif4"] }
[target.'cfg(target_arch = "aarch64")'.dependencies]
tfhe = { version = "0.8.0-alpha.8", features = ["boolean", "shortint", "integer", "aarch64-unix", "zk-pok", "experimental-force_fft_algo_dif4"] }
tfhe = { version = "0.8.0-alpha.9", features = ["boolean", "shortint", "integer", "aarch64-unix", "zk-pok", "experimental-force_fft_algo_dif4"] }

[dependencies]
# Common dependencies
Expand Down
5 changes: 3 additions & 2 deletions fhevm-engine/coprocessor/src/db_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::{BTreeSet, HashMap};
use std::str::FromStr;

use crate::types::{CoprocessorError, TfheTenantKeys};
use fhevm_engine_common::utils::{safe_deserialize_versioned, safe_deserialize_versioned_sks};
use sqlx::{query, Postgres};

/// Returns tenant id upon valid authorization request
Expand Down Expand Up @@ -148,9 +149,9 @@ where
.await?;

for key in keys {
let sks: tfhe::ServerKey = bincode::deserialize(&key.sks_key)
let sks: tfhe::ServerKey = safe_deserialize_versioned_sks(&key.sks_key)
.expect("We can't deserialize our own validated sks key");
let pks: tfhe::CompactPublicKey = bincode::deserialize(&key.pks_key)
let pks: tfhe::CompactPublicKey = safe_deserialize_versioned(&key.pks_key)
.expect("We can't deserialize our own validated pks key");
let public_params = <tfhe::zk::CompactPkePublicParams as tfhe::zk::CanonicalDeserialize>::deserialize_with_mode(
&*key.public_params,
Expand Down
77 changes: 50 additions & 27 deletions fhevm-engine/coprocessor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::db_queries::{
};
use crate::server::coprocessor::GenericResponse;
use crate::types::{CoprocessorError, TfheTenantKeys};
use crate::utils::{set_server_key_if_not_set, sort_computations_by_dependencies};
use crate::utils::sort_computations_by_dependencies;
use alloy::signers::local::PrivateKeySigner;
use alloy::signers::SignerSync;
use alloy::sol_types::SolStruct;
Expand All @@ -21,12 +21,13 @@ use fhevm_engine_common::tfhe_ops::{
try_expand_ciphertext_list, validate_fhe_type,
};
use fhevm_engine_common::types::{FhevmError, SupportedFheCiphertexts, SupportedFheOperations};
use fhevm_engine_common::utils::safe_deserialize_versioned_sks;
use lazy_static::lazy_static;
use prometheus::{register_int_counter, IntCounter};
use sha3::{Digest, Keccak256};
use sqlx::{query, Acquire};
use tokio::task::spawn_blocking;
use tonic::transport::Server;
use lazy_static::lazy_static;

pub mod common {
tonic::include_proto!("fhevm.common");
Expand All @@ -37,22 +38,46 @@ pub mod coprocessor {
}

lazy_static! {
static ref UPLOAD_INPUTS_COUNTER: IntCounter =
register_int_counter!("coprocessor_upload_inputs_count", "grpc calls for inputs upload endpoint").unwrap();
static ref UPLOAD_INPUTS_ERRORS: IntCounter =
register_int_counter!("coprocessor_upload_inputs_errors", "grpc errors while calling upload inputs").unwrap();
static ref ASYNC_COMPUTE_COUNTER: IntCounter =
register_int_counter!("coprocessor_async_compute_count", "grpc calls for async compute endpoint").unwrap();
static ref ASYNC_COMPUTE_ERRORS: IntCounter =
register_int_counter!("coprocessor_async_compute_errors", "grpc errors while calling async compute").unwrap();
static ref TRIVIAL_ENCRYPT_COUNTER: IntCounter =
register_int_counter!("coprocessor_trivial_encrypt_count", "grpc calls for trivial encrypt endpoint").unwrap();
static ref TRIVIAL_ENCRYPT_ERRORS: IntCounter =
register_int_counter!("coprocessor_trivial_encrypt_errors", "grpc errors while calling trivial encrypt").unwrap();
static ref GET_CIPHERTEXTS_COUNTER: IntCounter =
register_int_counter!("coprocessor_get_ciphertexts_count", "grpc calls for get ciphertexts endpoint").unwrap();
static ref GET_CIPHERTEXTS_ERRORS: IntCounter =
register_int_counter!("coprocessor_get_ciphertexts_errors", "grpc errors while calling get ciphertexts").unwrap();
static ref UPLOAD_INPUTS_COUNTER: IntCounter = register_int_counter!(
"coprocessor_upload_inputs_count",
"grpc calls for inputs upload endpoint"
)
.unwrap();
static ref UPLOAD_INPUTS_ERRORS: IntCounter = register_int_counter!(
"coprocessor_upload_inputs_errors",
"grpc errors while calling upload inputs"
)
.unwrap();
static ref ASYNC_COMPUTE_COUNTER: IntCounter = register_int_counter!(
"coprocessor_async_compute_count",
"grpc calls for async compute endpoint"
)
.unwrap();
static ref ASYNC_COMPUTE_ERRORS: IntCounter = register_int_counter!(
"coprocessor_async_compute_errors",
"grpc errors while calling async compute"
)
.unwrap();
static ref TRIVIAL_ENCRYPT_COUNTER: IntCounter = register_int_counter!(
"coprocessor_trivial_encrypt_count",
"grpc calls for trivial encrypt endpoint"
)
.unwrap();
static ref TRIVIAL_ENCRYPT_ERRORS: IntCounter = register_int_counter!(
"coprocessor_trivial_encrypt_errors",
"grpc errors while calling trivial encrypt"
)
.unwrap();
static ref GET_CIPHERTEXTS_COUNTER: IntCounter = register_int_counter!(
"coprocessor_get_ciphertexts_count",
"grpc calls for get ciphertexts endpoint"
)
.unwrap();
static ref GET_CIPHERTEXTS_ERRORS: IntCounter = register_int_counter!(
"coprocessor_get_ciphertexts_errors",
"grpc errors while calling get ciphertexts"
)
.unwrap();
}

pub struct CoprocessorService {
Expand Down Expand Up @@ -189,9 +214,9 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
request: tonic::Request<coprocessor::TrivialEncryptBatch>,
) -> std::result::Result<tonic::Response<coprocessor::GenericResponse>, tonic::Status> {
TRIVIAL_ENCRYPT_COUNTER.inc();
self.trivial_encrypt_ciphertexts_impl(request).await.inspect_err(|_| {
TRIVIAL_ENCRYPT_ERRORS.inc()
})
self.trivial_encrypt_ciphertexts_impl(request)
.await
.inspect_err(|_| TRIVIAL_ENCRYPT_ERRORS.inc())
}

async fn get_ciphertexts(
Expand Down Expand Up @@ -270,8 +295,6 @@ impl CoprocessorService {

let mut tfhe_work_set = tokio::task::JoinSet::new();

// server key is biiig, clone the pointer
let server_key = std::sync::Arc::new(server_key);
let mut contract_addresses = Vec::with_capacity(req.input_ciphertexts.len());
let mut user_addresses = Vec::with_capacity(req.input_ciphertexts.len());
for ci in &req.input_ciphertexts {
Expand Down Expand Up @@ -299,7 +322,7 @@ impl CoprocessorService {
let server_key = server_key.clone();
tfhe_work_set.spawn_blocking(
move || -> Result<_, (Box<(dyn std::error::Error + Send + Sync)>, usize)> {
set_server_key_if_not_set(tenant_id, &server_key);
tfhe::set_server_key(server_key.clone());
let expanded = try_expand_ciphertext_list(&cloned_input.input_payload)
.map_err(|e| {
let err: Box<(dyn std::error::Error + Send + Sync)> = Box::new(e);
Expand Down Expand Up @@ -401,7 +424,7 @@ impl CoprocessorService {
let blob_hash_clone = blob_hash.clone();
let server_key_clone = server_key.clone();
let (handle, serialized_ct, serialized_type) = spawn_blocking(move || {
set_server_key_if_not_set(tenant_id, &server_key_clone);
tfhe::set_server_key(server_key_clone);
let (serialized_type, serialized_ct) = the_ct.compress();
let mut handle_hash = Keccak256::new();
handle_hash.update(&blob_hash_clone);
Expand Down Expand Up @@ -643,7 +666,7 @@ impl CoprocessorService {
let sks = sks.pop().unwrap();
let cloned = req.values.clone();
let out_cts = tokio::task::spawn_blocking(move || {
let server_key: tfhe::ServerKey = bincode::deserialize(&sks.sks_key).unwrap();
let server_key: tfhe::ServerKey = safe_deserialize_versioned_sks(&sks.sks_key).unwrap();
tfhe::set_server_key(server_key);

// single threaded implementation, we can optimize later
Expand Down Expand Up @@ -742,4 +765,4 @@ impl CoprocessorService {

return Ok(tonic::Response::new(result));
}
}
}
2 changes: 1 addition & 1 deletion fhevm-engine/coprocessor/src/tests/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
},
},
tests::{
inputs::{test_random_user_address, test_random_contract_address},
inputs::{test_random_contract_address, test_random_user_address},
utils::{default_api_key, default_tenant_id, setup_test_app},
},
};
Expand Down
42 changes: 21 additions & 21 deletions fhevm-engine/coprocessor/src/tests/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,18 @@ async fn test_fhe_random_basic() -> Result<(), Box<dyn std::error::Error>> {
let decrypt_request = output_handles.clone();
let resp = decrypt_ciphertexts(&pool, 1, decrypt_request).await?;
let expected: Vec<DecryptionResult> = vec![
DecryptionResult { value: "true".to_string(), output_type: 0 },
DecryptionResult { value: "6".to_string(), output_type: 1 },
DecryptionResult { value: "6".to_string(), output_type: 2 },
DecryptionResult { value: "23046".to_string(), output_type: 3 },
DecryptionResult { value: "2257672710".to_string(), output_type: 4 },
DecryptionResult { value: "12138718414261803526".to_string(), output_type: 5 },
DecryptionResult { value: "130536719590611940049803920731387550214".to_string(), output_type: 6 },
DecryptionResult { value: "971176705489787087023559718483701127113677560326".to_string(), output_type: 7 },
DecryptionResult { value: "62210255757460412253332620363065848989112923584999887570035464828426661222918".to_string(), output_type: 8 },
DecryptionResult { value: "167958935840398111366003661819132943572579228212385323643009044778284654758971531763634195717060767316412295162146605242695852136468800900790045270694406".to_string(), output_type: 9 },
DecryptionResult { value: "127460563385689404084570635453516642982330737396307363709535669246693726363369279326274116849562765049033667934125131507607869225026009107310544028242879211116101076829363291657387574479716476869613221980036198477470920343187777849916436388023322996436007563319615378730113313056846971613305517149919649028614".to_string(), output_type: 10 },
DecryptionResult { value: "29687326363179539154232170826093317060572491263948154715413122357200687474061448043555291795321984983113829977114301561317315809196828773909981565653610082891472340553741585442577497506409472143098823132371629384036451019214072899732235656145602725111017828708028912154841404994944466545632048686969494346234325709069045453046020648098209481065154942201598888424765642988091655940417557742117518483932517015160272576663001732809302519121630949039706341063098676812339442637939392896074884484156187775746589025164758187166306751922076107008755031211360389068550389609734783888124482836062055425119177200121882346609158".to_string(), output_type: 11 }
DecryptionResult { value: "false".to_string(), output_type: 0 },
DecryptionResult { value: "0".to_string(), output_type: 1 },
DecryptionResult { value: "208".to_string(), output_type: 2 },
DecryptionResult { value: "18384".to_string(), output_type: 3 },
DecryptionResult { value: "546654160".to_string(), output_type: 4 },
DecryptionResult { value: "5309702218429319120".to_string(), output_type: 5 },
DecryptionResult { value: "259532418979733305351113308189425485776".to_string(), output_type: 6 },
DecryptionResult { value: "1267867068236038429610894726805758983204846782416".to_string(), output_type: 7 },
DecryptionResult { value: "30164121005611094063260660418297722452727051469520960924993268496981611071440".to_string(), output_type: 8 },
DecryptionResult { value: "6213722850039064669433671652647529255789532115960532372768030528069295575413451947625020803543441293820836550947188152513313074353655122817702324075841488".to_string(), output_type: 9 },
DecryptionResult { value: "135596181018050014151440026328436435096551725053732154262022484620871840987960807275950086972107187870322269980270968306992213801802722113379074637321577464212174225858453662873604986351675422583583022204649943087340917217757907768317732046913567629376792967659037503980904999476225028235881045286170343589840".to_string(), output_type: 10 },
DecryptionResult { value: "20926228578992516717417088826312477563346530647238102649041733547761442048502490725118956138164779152604543876429923941857666356145969053143096763294553340917748520683808183210744459554376313461218407555416764225017989632233017323561671951549366342206122891759465623534776122756268283899411003375412445896653512731607929423488100112560489030070895391125926875840684946772723738520663564318353443151393937189264351070849901049341958499762590673504671858266930627927728298335449590844376609353638317270493141629796367937654152469055194322006100677265326779422590457512219671735471500255183382079593262797753352340195280".to_string(), output_type: 11 }
];

println!("results: {:#?}", resp);
Expand Down Expand Up @@ -186,15 +186,15 @@ async fn test_fhe_random_bounded() -> Result<(), Box<dyn std::error::Error>> {
"28948022309329048855892746252171976963317496166410141009864396001978282409984",
];
let results = [
"true",
"2",
"6",
"6662",
"110189062",
"2915346377407027718",
"45466127860377324183960268873445497350",
"240425886824335627921717302125559617285711288838",
"4314211138802314541547127858721895062477931252179605550306672824470096402950",
"false",
"0",
"80",
"2000",
"546654160",
"698016200001931216",
"4320643789029457753582352615599327184",
"171740840237861240958131102268546718462897375184",
"1216098696282045207367914166125745489409555303110819915128872495003328661456",
];

for (idx, the_type) in supported_types().iter().enumerate() {
Expand Down
5 changes: 3 additions & 2 deletions fhevm-engine/coprocessor/src/tests/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::cli::Args;
use fhevm_engine_common::tfhe_ops::current_ciphertext_version;
use fhevm_engine_common::types::SupportedFheCiphertexts;
use fhevm_engine_common::utils::{safe_deserialize_versioned, safe_deserialize_versioned_sks};
use rand::Rng;
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU16, Ordering};
Expand Down Expand Up @@ -287,8 +288,8 @@ pub async fn decrypt_ciphertexts(

let mut values = tokio::task::spawn_blocking(move || {
let client_key: tfhe::ClientKey =
bincode::deserialize(&keys.cks_key.clone().unwrap()).unwrap();
let sks: tfhe::ServerKey = bincode::deserialize(&keys.sks_key).unwrap();
safe_deserialize_versioned(&keys.cks_key.clone().unwrap()).unwrap();
let sks: tfhe::ServerKey = safe_deserialize_versioned_sks(&keys.sks_key).unwrap();
tfhe::set_server_key(sks);

let mut decrypted: Vec<(Vec<u8>, DecryptionResult)> = Vec::with_capacity(cts.len());
Expand Down
40 changes: 27 additions & 13 deletions fhevm-engine/coprocessor/src/tfhe_worker.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,45 @@
use crate::utils::set_server_key_if_not_set;
use crate::{db_queries::populate_cache_with_tenant_keys, types::TfheTenantKeys};
use fhevm_engine_common::types::SupportedFheCiphertexts;
use fhevm_engine_common::{
tfhe_ops::{current_ciphertext_version, perform_fhe_operation},
types::SupportedFheOperations,
};
use lazy_static::lazy_static;
use prometheus::{register_int_counter, IntCounter};
use sqlx::{postgres::PgListener, query, Acquire};
use std::{
collections::{BTreeSet, HashMap},
num::NonZeroUsize,
};
use lazy_static::lazy_static;

lazy_static! {
static ref WORKER_ERRORS_COUNTER: IntCounter =
register_int_counter!("coprocessor_worker_errors", "worker errors encountered").unwrap();
static ref WORK_ITEMS_POLL_COUNTER: IntCounter =
register_int_counter!("coprocessor_work_items_polls", "times work items are polled from database").unwrap();
static ref WORK_ITEMS_NOTIFICATIONS_COUNTER: IntCounter =
register_int_counter!("coprocessor_work_items_notifications", "times instant notifications for work items received from the database").unwrap();
static ref WORK_ITEMS_FOUND_COUNTER: IntCounter =
register_int_counter!("coprocessor_work_items_found", "work items queried from database").unwrap();
static ref WORK_ITEMS_ERRORS_COUNTER: IntCounter =
register_int_counter!("coprocessor_work_items_errors", "work items errored out during computation").unwrap();
static ref WORK_ITEMS_PROCESSED_COUNTER: IntCounter =
register_int_counter!("coprocessor_work_items_processed", "work items successfully processed and stored in the database").unwrap();
static ref WORK_ITEMS_POLL_COUNTER: IntCounter = register_int_counter!(
"coprocessor_work_items_polls",
"times work items are polled from database"
)
.unwrap();
static ref WORK_ITEMS_NOTIFICATIONS_COUNTER: IntCounter = register_int_counter!(
"coprocessor_work_items_notifications",
"times instant notifications for work items received from the database"
)
.unwrap();
static ref WORK_ITEMS_FOUND_COUNTER: IntCounter = register_int_counter!(
"coprocessor_work_items_found",
"work items queried from database"
)
.unwrap();
static ref WORK_ITEMS_ERRORS_COUNTER: IntCounter = register_int_counter!(
"coprocessor_work_items_errors",
"work items errored out during computation"
)
.unwrap();
static ref WORK_ITEMS_PROCESSED_COUNTER: IntCounter = register_int_counter!(
"coprocessor_work_items_processed",
"work items successfully processed and stored in the database"
)
.unwrap();
}

pub async fn run_tfhe_worker(
Expand Down Expand Up @@ -199,7 +213,7 @@ async fn tfhe_worker_cycle(
let keys = rk
.get(&w.tenant_id)
.expect("Can't get tenant key from cache");
set_server_key_if_not_set(w.tenant_id, &keys.sks);
tfhe::set_server_key(keys.sks.clone());
}

let mut deserialized_cts: Vec<SupportedFheCiphertexts> =
Expand Down
16 changes: 1 addition & 15 deletions fhevm-engine/coprocessor/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use std::{
cell::Cell,
collections::{BTreeSet, HashMap, HashSet},
};
use std::collections::{BTreeSet, HashMap, HashSet};

use fhevm_engine_common::types::{FhevmError, SupportedFheOperations};

Expand Down Expand Up @@ -324,14 +321,3 @@ fn test_multi_level_circular_dependency_detection() {
}
}
}

pub fn set_server_key_if_not_set(tenant_id: i32, sks: &tfhe::ServerKey) {
thread_local! {
static TFHE_TENANT_ID: Cell<i32> = Cell::new(-1);
}

if tenant_id != TFHE_TENANT_ID.get() {
tfhe::set_server_key(sks.clone());
TFHE_TENANT_ID.set(tenant_id);
}
}
Loading

0 comments on commit 0ebd723

Please sign in to comment.