From ba7f740e34f5c1fc6d6f511409b25359e2434ac3 Mon Sep 17 00:00:00 2001 From: David Kazlauskas Date: Wed, 21 Aug 2024 08:26:11 +0300 Subject: [PATCH] build: move out tfhe ops/types and errors to common crate --- fhevm-engine/Cargo.lock | 6 + fhevm-engine/coprocessor/Cargo.toml | 4 +- fhevm-engine/coprocessor/src/main.rs | 22 +- fhevm-engine/coprocessor/src/server.rs | 25 +- .../coprocessor/src/tests/operators.rs | 8 +- fhevm-engine/coprocessor/src/tfhe_worker.rs | 15 +- fhevm-engine/coprocessor/src/types.rs | 349 +--------- fhevm-engine/coprocessor/src/utils.rs | 17 +- fhevm-engine/fhevm-engine-common/Cargo.toml | 4 + fhevm-engine/fhevm-engine-common/src/lib.rs | 27 +- .../src/tfhe_ops.rs | 596 +++++++++--------- fhevm-engine/fhevm-engine-common/src/types.rs | 353 +++++++++++ 12 files changed, 745 insertions(+), 681 deletions(-) rename fhevm-engine/{coprocessor => fhevm-engine-common}/src/tfhe_ops.rs (93%) create mode 100644 fhevm-engine/fhevm-engine-common/src/types.rs diff --git a/fhevm-engine/Cargo.lock b/fhevm-engine/Cargo.lock index da2177f1..bdaea17c 100644 --- a/fhevm-engine/Cargo.lock +++ b/fhevm-engine/Cargo.lock @@ -796,6 +796,12 @@ checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" [[package]] name = "fhevm-engine-common" version = "0.1.0" +dependencies = [ + "bincode", + "hex", + "strum", + "tfhe", +] [[package]] name = "fixedbitset" diff --git a/fhevm-engine/coprocessor/Cargo.toml b/fhevm-engine/coprocessor/Cargo.toml index 31cbdc08..c569f94c 100644 --- a/fhevm-engine/coprocessor/Cargo.toml +++ b/fhevm-engine/coprocessor/Cargo.toml @@ -20,11 +20,11 @@ regex = "1.10.5" lazy_static = "1.5.0" clap.workspace = true lru = "0.12.3" -bincode = "1.3.3" hex = "0.4" -strum = { version = "0.26", features = ["derive"] } bigdecimal = "0.4" fhevm-engine-common = { path = "../fhevm-engine-common" } +strum = { version = "0.26", features = ["derive"] } +bincode = "1.3.3" [dev-dependencies] testcontainers = "0.21" diff --git a/fhevm-engine/coprocessor/src/main.rs b/fhevm-engine/coprocessor/src/main.rs index 4e3aa407..e8252515 100644 --- a/fhevm-engine/coprocessor/src/main.rs +++ b/fhevm-engine/coprocessor/src/main.rs @@ -1,3 +1,4 @@ +use fhevm_engine_common::generate_fhe_keys; use tokio::task::JoinSet; mod cli; @@ -5,16 +6,11 @@ mod db_queries; mod server; #[cfg(test)] mod tests; -mod tfhe_ops; mod tfhe_worker; mod types; mod utils; fn main() { - - // TODO: remove, just to make sure it works - let _ = fhevm_engine_common::add(5, 5); - let args = crate::cli::parse_args(); assert!( args.work_items_batch_size < args.tenant_key_cache_size, @@ -22,7 +18,7 @@ fn main() { ); if args.generate_fhe_keys { - generate_fhe_keys(); + generate_dump_fhe_keys(); } else { start_runtime(args, None); } @@ -85,20 +81,16 @@ async fn async_main( Ok(()) } -fn generate_fhe_keys() { +fn generate_dump_fhe_keys() { let output_dir = "fhevm-keys"; println!("Generating keys..."); - let (client_key, server_key) = tfhe::generate_keys(tfhe::ConfigBuilder::default().build()); - let compact_key = tfhe::CompactPublicKey::new(&client_key); - let client_key = bincode::serialize(&client_key).unwrap(); - let server_key = bincode::serialize(&server_key).unwrap(); - let compact_key = bincode::serialize(&compact_key).unwrap(); + let keys = generate_fhe_keys(); println!("Creating directory {output_dir}"); std::fs::create_dir_all(output_dir).unwrap(); println!("Creating file {output_dir}/cks"); - std::fs::write(format!("{output_dir}/cks"), client_key).unwrap(); + std::fs::write(format!("{output_dir}/cks"), keys.client_key).unwrap(); println!("Creating file {output_dir}/pks"); - std::fs::write(format!("{output_dir}/pks"), compact_key).unwrap(); + std::fs::write(format!("{output_dir}/pks"), keys.compact_public_key).unwrap(); println!("Creating file {output_dir}/sks"); - std::fs::write(format!("{output_dir}/sks"), server_key).unwrap(); + std::fs::write(format!("{output_dir}/sks"), keys.server_key).unwrap(); } diff --git a/fhevm-engine/coprocessor/src/server.rs b/fhevm-engine/coprocessor/src/server.rs index 685f3cab..09816f35 100644 --- a/fhevm-engine/coprocessor/src/server.rs +++ b/fhevm-engine/coprocessor/src/server.rs @@ -1,8 +1,7 @@ use crate::db_queries::{check_if_api_key_is_valid, check_if_ciphertexts_exist_in_db}; use crate::server::coprocessor::GenericResponse; -use crate::tfhe_ops::{ - self, check_fhe_operand_types, current_ciphertext_version, debug_trivial_encrypt_be_bytes, -}; +use fhevm_engine_common::tfhe_ops::{check_fhe_operand_types, current_ciphertext_version, debug_trivial_encrypt_be_bytes, deserialize_fhe_ciphertext}; +use fhevm_engine_common::types::FhevmError; use crate::types::CoprocessorError; use crate::utils::sort_computations_by_dependencies; use coprocessor::async_computation_input::Input; @@ -167,7 +166,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let mut decrypted: Vec = Vec::with_capacity(cts.len()); for ct in cts { let deserialized = - tfhe_ops::deserialize_fhe_ciphertext(ct.ciphertext_type, &ct.ciphertext) + deserialize_fhe_ciphertext(ct.ciphertext_type, &ct.ciphertext) .unwrap(); decrypted.push(DebugDecryptResponseSingle { output_type: ct.ciphertext_type as i32, @@ -205,7 +204,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let ciphertext_type: i16 = i_ct .ciphertext_type .try_into() - .map_err(|_e| CoprocessorError::UnknownFheType(i_ct.ciphertext_type))?; + .map_err(|_e| CoprocessorError::FhevmError(FhevmError::UnknownFheType(i_ct.ciphertext_type)))?; let _ = sqlx::query!(" INSERT INTO ciphertexts(tenant_id, handle, ciphertext, ciphertext_version, ciphertext_type) VALUES($1, $2, $3, $4, $5) @@ -254,6 +253,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let mut handle_types = Vec::with_capacity(comp.inputs.len()); let mut is_computation_scalar = false; let mut this_comp_inputs: Vec> = Vec::with_capacity(comp.inputs.len()); + let mut is_scalar_op_vec: Vec = Vec::with_capacity(comp.inputs.len()); for (idx, ih) in comp.inputs.iter().enumerate() { if let Some(input) = &ih.input { match input { @@ -263,27 +263,30 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ .expect("this must be found if operand is non scalar"); handle_types.push(*ct_type); this_comp_inputs.push(ih.clone()); + is_scalar_op_vec.push(false); } Input::Scalar(sc) => { is_computation_scalar = true; handle_types.push(-1); this_comp_inputs.push(sc.clone()); + is_scalar_op_vec.push(true); assert!(idx == 1, "we should have checked earlier that only second operand can be scalar"); } } } } - computations_inputs.push(this_comp_inputs); - are_comps_scalar.push(is_computation_scalar); // check before we insert computation that it has // to succeed according to the type system let output_type = check_fhe_operand_types( comp.operation, &handle_types, - is_computation_scalar, - &comp.inputs, - )?; + &this_comp_inputs, + &is_scalar_op_vec, + ).map_err(|e| CoprocessorError::FhevmError(e))?; + + computations_inputs.push(this_comp_inputs); + are_comps_scalar.push(is_computation_scalar); // fill in types with output handles that are computed as we go assert!(ct_types .insert(comp.output_handle.clone(), output_type) @@ -300,7 +303,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let fhe_operation: i16 = comp .operation .try_into() - .map_err(|_| CoprocessorError::UnknownFheOperation(comp.operation))?; + .map_err(|_| CoprocessorError::FhevmError(FhevmError::UnknownFheOperation(comp.operation)))?; let res = query!( " INSERT INTO computations(tenant_id, output_handle, dependencies, fhe_operation, is_completed, is_scalar) diff --git a/fhevm-engine/coprocessor/src/tests/operators.rs b/fhevm-engine/coprocessor/src/tests/operators.rs index d51be612..e3b7ba23 100644 --- a/fhevm-engine/coprocessor/src/tests/operators.rs +++ b/fhevm-engine/coprocessor/src/tests/operators.rs @@ -7,14 +7,12 @@ use crate::tests::utils::wait_until_all_ciphertexts_computed; use crate::{ server::coprocessor::{async_computation_input::Input, AsyncComputationInput}, tests::utils::{default_api_key, setup_test_app}, - tfhe_ops::{ - does_fhe_operation_support_both_encrypted_operands, does_fhe_operation_support_scalar, - }, - types::{FheOperationType, SupportedFheOperations}, }; use bigdecimal::num_bigint::BigInt; -use std::{ops::Not, str::FromStr}; +use fhevm_engine_common::tfhe_ops::{does_fhe_operation_support_both_encrypted_operands, does_fhe_operation_support_scalar}; +use fhevm_engine_common::types::{FheOperationType, SupportedFheOperations}; use strum::IntoEnumIterator; +use std::{ops::Not, str::FromStr}; use tonic::metadata::MetadataValue; struct BinaryOperatorTestCase { diff --git a/fhevm-engine/coprocessor/src/tfhe_worker.rs b/fhevm-engine/coprocessor/src/tfhe_worker.rs index 00191183..2dbadbb0 100644 --- a/fhevm-engine/coprocessor/src/tfhe_worker.rs +++ b/fhevm-engine/coprocessor/src/tfhe_worker.rs @@ -1,7 +1,8 @@ -use crate::tfhe_ops::{ +use fhevm_engine_common::tfhe_ops::{ current_ciphertext_version, deserialize_fhe_ciphertext, perform_fhe_operation, }; -use crate::types::{SupportedFheCiphertexts, TfheTenantKeys}; +use crate::types::TfheTenantKeys; +use fhevm_engine_common::types::SupportedFheCiphertexts; use sqlx::{postgres::PgListener, query, Acquire}; use std::{ cell::Cell, @@ -223,13 +224,19 @@ async fn tfhe_worker_cycle( } else { deserialized_cts.push( deserialize_fhe_ciphertext(*ct_type, ct_bytes.as_slice()) - .map_err(|e| (e, w.tenant_id, w.output_handle.clone()))?, + .map_err(|e| { + let err: Box = Box::new(e); + (err, w.tenant_id, w.output_handle.clone()) + })?, ); } } let res = perform_fhe_operation(w.fhe_operation, &deserialized_cts) - .map_err(|e| (e, w.tenant_id, w.output_handle.clone()))?; + .map_err(|e| { + let err: Box = Box::new(e); + (err, w.tenant_id, w.output_handle.clone()) + })?; let (db_type, db_bytes) = res.serialize(); Ok((w, db_type, db_bytes)) diff --git a/fhevm-engine/coprocessor/src/types.rs b/fhevm-engine/coprocessor/src/types.rs index 54037f18..1f814b6d 100644 --- a/fhevm-engine/coprocessor/src/types.rs +++ b/fhevm-engine/coprocessor/src/types.rs @@ -1,26 +1,19 @@ -use tfhe::{integer::U256, prelude::FheDecrypt}; +use fhevm_engine_common::types::FhevmError; #[derive(Debug)] pub enum CoprocessorError { DbError(sqlx::Error), Unauthorized, - UnknownFheOperation(i32), - UnknownFheType(i32), + FhevmError(FhevmError), DuplicateOutputHandleInBatch(String), CiphertextHandleLongerThan64Bytes, CiphertextHandleMustBeAtLeast1Byte(String), UnexistingInputCiphertextsFound(Vec), OutputHandleIsAlsoInputHandle(String), - UnknownCiphertextType(i16), ComputationInputIsUndefined { computation_output_handle: String, computation_inputs_index: usize, }, - OnlySecondOperandCanBeScalar { - computation_output_handle: String, - scalar_input_index: usize, - only_allowed_scalar_input_index: usize, - }, TooManyCiphertextsInBatch { maximum_allowed: usize, got: usize, @@ -29,121 +22,43 @@ pub enum CoprocessorError { uncomputable_output_handle: String, uncomputable_handle_dependency: String, }, - UnexpectedOperandCountForFheOperation { - fhe_operation: i32, - fhe_operation_name: String, - expected_operands: usize, - got_operands: usize, - }, - FheIfThenElseUnexpectedOperandTypes { - fhe_operation: i32, - fhe_operation_name: String, - first_operand_type: i16, - first_expected_operand_type: i16, - first_expected_operand_type_name: String, - }, - FheIfThenElseMismatchingSecondAndThirdOperatorTypes { - fhe_operation: i32, - fhe_operation_name: String, - second_operand_type: i16, - third_operand_type: i16, - }, - UnexpectedCastOperandTypes { - fhe_operation: i32, - fhe_operation_name: String, - expected_operator_combination: Vec, - }, - UnexpectedCastOperandSizeForScalarOperand { - fhe_operation: i32, - fhe_operation_name: String, - expected_scalar_operand_bytes: usize, - got_bytes: usize, - }, - FheOperationDoesntSupportScalar { - fhe_operation: i32, - fhe_operation_name: String, - scalar_requested: bool, - scalar_supported: bool, - }, - FheOperationDoesntHaveUniformTypesAsInput { - fhe_operation: i32, - fhe_operation_name: String, - operand_types: Vec, - }, - FheOperationScalarDivisionByZero { - lhs_handle: String, - rhs_value: String, - fhe_operation: i32, - fhe_operation_name: String, - }, } impl std::fmt::Display for CoprocessorError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - CoprocessorError::DbError(dbe) => { + Self::DbError(dbe) => { write!(f, "Coprocessor db error: {:?}", dbe) } - CoprocessorError::Unauthorized => { + Self::Unauthorized => { write!(f, "API key unknown/invalid/not provided") } - CoprocessorError::UnknownFheOperation(op) => { - write!(f, "Unknown fhe operation: {}", op) - } - CoprocessorError::UnknownFheType(op) => { - write!(f, "Unknown fhe type: {}", op) - } - CoprocessorError::DuplicateOutputHandleInBatch(op) => { + Self::DuplicateOutputHandleInBatch(op) => { write!(f, "Duplicate output handle in ciphertext batch: {}", op) } - CoprocessorError::CiphertextHandleLongerThan64Bytes => { + Self::CiphertextHandleLongerThan64Bytes => { write!(f, "Found ciphertext handle longer than 64 bytes") } - CoprocessorError::CiphertextHandleMustBeAtLeast1Byte(handle) => { + Self::CiphertextHandleMustBeAtLeast1Byte(handle) => { write!(f, "Found ciphertext handle less than 4 bytes: {handle}") } - CoprocessorError::UnexistingInputCiphertextsFound(handles) => { + Self::UnexistingInputCiphertextsFound(handles) => { write!(f, "Ciphertexts not found: {:?}", handles) } - CoprocessorError::OutputHandleIsAlsoInputHandle(handle) => { + Self::OutputHandleIsAlsoInputHandle(handle) => { write!( f, "Output handle is also on of the input handles: {}", handle ) } - CoprocessorError::UnknownCiphertextType(the_type) => { - write!(f, "Unknown input ciphertext type: {}", the_type) - } - CoprocessorError::UnexpectedOperandCountForFheOperation { - fhe_operation, - fhe_operation_name, - expected_operands, - got_operands, - } => { - write!(f, "fhe operation number {fhe_operation} ({fhe_operation_name}) received unexpected operand count, expected: {expected_operands}, received: {got_operands}") - } - CoprocessorError::FheOperationDoesntSupportScalar { - fhe_operation, - fhe_operation_name, - .. - } => { - write!(f, "fhe operation number {fhe_operation} ({fhe_operation_name}) doesn't support scalar computation") - } - CoprocessorError::FheOperationDoesntHaveUniformTypesAsInput { - fhe_operation, - fhe_operation_name, - operand_types, - } => { - write!(f, "fhe operation number {fhe_operation} ({fhe_operation_name}) expects uniform types as input, received: {:?}", operand_types) - } - CoprocessorError::CiphertextComputationDependencyLoopDetected { + Self::CiphertextComputationDependencyLoopDetected { uncomputable_output_handle, uncomputable_handle_dependency, } => { write!(f, "fhe computation with output handle {uncomputable_output_handle} with dependency {:?} has circular dependency and is uncomputable", uncomputable_handle_dependency) } - CoprocessorError::TooManyCiphertextsInBatch { + Self::TooManyCiphertextsInBatch { maximum_allowed, got, } => { @@ -152,47 +67,14 @@ impl std::fmt::Display for CoprocessorError { "maximum ciphertexts exceeded in batch, maximum: {maximum_allowed}, got: {got}" ) } - CoprocessorError::FheOperationScalarDivisionByZero { - lhs_handle, - rhs_value, - fhe_operation, - fhe_operation_name, - } => { - write!(f, "zero on the right side of scalar division, lhs handle: {lhs_handle}, rhs value: {rhs_value}, fhe operation: {fhe_operation} fhe operation name:{fhe_operation_name}") - } - CoprocessorError::ComputationInputIsUndefined { + Self::ComputationInputIsUndefined { computation_output_handle, computation_inputs_index, } => { write!(f, "computation has undefined input, output handle: {computation_output_handle}, input index: {computation_inputs_index}") } - CoprocessorError::OnlySecondOperandCanBeScalar { - computation_output_handle, - scalar_input_index, - only_allowed_scalar_input_index, - } => { - write!(f, "computation has scalar operand which is not the second operand, output handle: {computation_output_handle}, scalar input index: {scalar_input_index}, only allowed scalar input index: {only_allowed_scalar_input_index}") - } - CoprocessorError::UnexpectedCastOperandTypes { - fhe_operation, - fhe_operation_name, - expected_operator_combination, - } => { - write!(f, "unexpected operand types for cast, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, expected operand combination: {:?}", expected_operator_combination) - } - CoprocessorError::UnexpectedCastOperandSizeForScalarOperand { - fhe_operation, - fhe_operation_name, - expected_scalar_operand_bytes, - got_bytes, - } => { - write!(f, "unexpected operand size for cast, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, expected bytes: {}, got bytes: {}", expected_scalar_operand_bytes, got_bytes) - } - CoprocessorError::FheIfThenElseUnexpectedOperandTypes { fhe_operation, fhe_operation_name, first_operand_type, first_expected_operand_type, .. } => { - write!(f, "fhe if then else first operand should always be FheBool, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, first operand type: {first_operand_type}, first operand expected type: {first_expected_operand_type}") - } - CoprocessorError::FheIfThenElseMismatchingSecondAndThirdOperatorTypes { fhe_operation, fhe_operation_name, second_operand_type, third_operand_type } => { - write!(f, "fhe if then else second and third operand types don't match, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, second operand type: {second_operand_type}, third operand type: {third_operand_type}") + Self::FhevmError(e) => { + write!(f, "fhevm error: {:?}", e) } } } @@ -218,206 +100,3 @@ pub struct TfheTenantKeys { #[allow(dead_code)] pub pks: tfhe::CompactPublicKey, } - -pub enum SupportedFheCiphertexts { - FheBool(tfhe::FheBool), - FheUint8(tfhe::FheUint8), - FheUint16(tfhe::FheUint16), - FheUint32(tfhe::FheUint32), - FheUint64(tfhe::FheUint64), - Scalar(U256), -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, strum::EnumIter)] -#[repr(i8)] -pub enum SupportedFheOperations { - FheAdd = 0, - FheSub = 1, - FheMul = 2, - FheDiv = 3, - FheRem = 4, - FheBitAnd = 5, - FheBitOr = 6, - FheBitXor = 7, - FheShl = 8, - FheShr = 9, - FheRotl = 10, - FheRotr = 11, - FheEq = 12, - FheNe = 13, - FheGe = 14, - FheGt = 15, - FheLe = 16, - FheLt = 17, - FheMin = 18, - FheMax = 19, - FheNeg = 20, - FheNot = 21, - FheCast = 30, - FheIfThenElse = 31, -} - -#[derive(PartialEq, Eq)] -pub enum FheOperationType { - Binary, - Unary, - Other, -} - -impl SupportedFheCiphertexts { - pub fn serialize(&self) -> (i16, Vec) { - let type_num = self.type_num(); - match self { - SupportedFheCiphertexts::FheBool(v) => (type_num, bincode::serialize(v).unwrap()), - SupportedFheCiphertexts::FheUint8(v) => (type_num, bincode::serialize(v).unwrap()), - SupportedFheCiphertexts::FheUint16(v) => (type_num, bincode::serialize(v).unwrap()), - SupportedFheCiphertexts::FheUint32(v) => (type_num, bincode::serialize(v).unwrap()), - SupportedFheCiphertexts::FheUint64(v) => (type_num, bincode::serialize(v).unwrap()), - SupportedFheCiphertexts::Scalar(_) => { - panic!("we should never need to serialize scalar") - } - } - } - - pub fn type_num(&self) -> i16 { - match self { - SupportedFheCiphertexts::FheBool(_) => 1, - SupportedFheCiphertexts::FheUint8(_) => 2, - SupportedFheCiphertexts::FheUint16(_) => 3, - SupportedFheCiphertexts::FheUint32(_) => 4, - SupportedFheCiphertexts::FheUint64(_) => 5, - SupportedFheCiphertexts::Scalar(_) => { - panic!("we should never need to serialize scalar") - } - } - } - - pub fn decrypt(&self, client_key: &tfhe::ClientKey) -> String { - match self { - SupportedFheCiphertexts::FheBool(v) => v.decrypt(client_key).to_string(), - SupportedFheCiphertexts::FheUint8(v) => { - FheDecrypt::::decrypt(v, client_key).to_string() - } - SupportedFheCiphertexts::FheUint16(v) => { - FheDecrypt::::decrypt(v, client_key).to_string() - } - SupportedFheCiphertexts::FheUint32(v) => { - FheDecrypt::::decrypt(v, client_key).to_string() - } - SupportedFheCiphertexts::FheUint64(v) => { - FheDecrypt::::decrypt(v, client_key).to_string() - } - SupportedFheCiphertexts::Scalar(v) => { - let (l, h) = v.to_low_high_u128(); - format!("{l}{h}") - } - } - } -} - -impl SupportedFheOperations { - pub fn op_type(&self) -> FheOperationType { - match self { - SupportedFheOperations::FheAdd - | SupportedFheOperations::FheSub - | SupportedFheOperations::FheMul - | SupportedFheOperations::FheDiv - | SupportedFheOperations::FheRem - | SupportedFheOperations::FheBitAnd - | SupportedFheOperations::FheBitOr - | SupportedFheOperations::FheBitXor - | SupportedFheOperations::FheShl - | SupportedFheOperations::FheShr - | SupportedFheOperations::FheRotl - | SupportedFheOperations::FheRotr - | SupportedFheOperations::FheEq - | SupportedFheOperations::FheNe - | SupportedFheOperations::FheGe - | SupportedFheOperations::FheGt - | SupportedFheOperations::FheLe - | SupportedFheOperations::FheLt - | SupportedFheOperations::FheMin - | SupportedFheOperations::FheMax => FheOperationType::Binary, - SupportedFheOperations::FheNot | SupportedFheOperations::FheNeg => { - FheOperationType::Unary - } - SupportedFheOperations::FheIfThenElse | SupportedFheOperations::FheCast => { - FheOperationType::Other - } - } - } - - pub fn is_comparison(&self) -> bool { - match self { - SupportedFheOperations::FheEq - | SupportedFheOperations::FheNe - | SupportedFheOperations::FheGe - | SupportedFheOperations::FheGt - | SupportedFheOperations::FheLe - | SupportedFheOperations::FheLt => true, - _ => false, - } - } -} - -impl TryFrom for SupportedFheOperations { - type Error = CoprocessorError; - - fn try_from(value: i16) -> Result { - let res = match value { - 0 => Ok(SupportedFheOperations::FheAdd), - 1 => Ok(SupportedFheOperations::FheSub), - 2 => Ok(SupportedFheOperations::FheMul), - 3 => Ok(SupportedFheOperations::FheDiv), - 4 => Ok(SupportedFheOperations::FheRem), - 5 => Ok(SupportedFheOperations::FheBitAnd), - 6 => Ok(SupportedFheOperations::FheBitOr), - 7 => Ok(SupportedFheOperations::FheBitXor), - 8 => Ok(SupportedFheOperations::FheShl), - 9 => Ok(SupportedFheOperations::FheShr), - 10 => Ok(SupportedFheOperations::FheRotl), - 11 => Ok(SupportedFheOperations::FheRotr), - 12 => Ok(SupportedFheOperations::FheEq), - 13 => Ok(SupportedFheOperations::FheNe), - 14 => Ok(SupportedFheOperations::FheGe), - 15 => Ok(SupportedFheOperations::FheGt), - 16 => Ok(SupportedFheOperations::FheLe), - 17 => Ok(SupportedFheOperations::FheLt), - 18 => Ok(SupportedFheOperations::FheMin), - 19 => Ok(SupportedFheOperations::FheMax), - 20 => Ok(SupportedFheOperations::FheNeg), - 21 => Ok(SupportedFheOperations::FheNot), - 30 => Ok(SupportedFheOperations::FheCast), - 31 => Ok(SupportedFheOperations::FheIfThenElse), - _ => Err(CoprocessorError::UnknownFheOperation(value as i32)), - }; - - // ensure we're always having the same value serialized back and forth - if let Ok(v) = &res { - assert_eq!(v.clone() as i16, value); - } - - res - } -} - -// we get i32 from protobuf (smaller types unsupported) -// but in database we store i16 -impl TryFrom for SupportedFheOperations { - type Error = CoprocessorError; - - fn try_from(value: i32) -> Result { - let initial_value: i16 = value - .try_into() - .map_err(|_| CoprocessorError::UnknownFheOperation(value))?; - - let final_value: Result = initial_value.try_into(); - final_value - } -} - -impl From for i16 { - fn from(value: SupportedFheOperations) -> Self { - value as i16 - } -} diff --git a/fhevm-engine/coprocessor/src/utils.rs b/fhevm-engine/coprocessor/src/utils.rs index 7e629ab7..438ea61f 100644 --- a/fhevm-engine/coprocessor/src/utils.rs +++ b/fhevm-engine/coprocessor/src/utils.rs @@ -1,5 +1,7 @@ use std::collections::{BTreeSet, HashMap, HashSet}; +use fhevm_engine_common::types::FhevmError; + #[cfg(test)] use crate::server::coprocessor::AsyncComputationInput; @@ -69,14 +71,13 @@ pub fn sort_computations_by_dependencies<'a>( Input::Scalar(sc_bytes) => { check_valid_ciphertext_handle(&sc_bytes)?; if dep_idx != 1 { - return Err(CoprocessorError::OnlySecondOperandCanBeScalar { - computation_output_handle: format!( - "0x{}", - hex::encode(&comp.output_handle) - ), - scalar_input_index: dep_idx, - only_allowed_scalar_input_index: 1, - }); + // TODO: remove wrapping after refactor + return Err(CoprocessorError::FhevmError( + FhevmError::FheOperationOnlySecondOperandCanBeScalar { + scalar_input_index: dep_idx, + only_allowed_scalar_input_index: 1, + } + )); } is_scalar_operand = true; } diff --git a/fhevm-engine/fhevm-engine-common/Cargo.toml b/fhevm-engine/fhevm-engine-common/Cargo.toml index 5cdd35bf..f072e6e2 100644 --- a/fhevm-engine/fhevm-engine-common/Cargo.toml +++ b/fhevm-engine/fhevm-engine-common/Cargo.toml @@ -4,3 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +tfhe.workspace = true +strum = { version = "0.26", features = ["derive"] } +bincode = "1.3.3" +hex = "0.4" \ No newline at end of file diff --git a/fhevm-engine/fhevm-engine-common/src/lib.rs b/fhevm-engine/fhevm-engine-common/src/lib.rs index 7d12d9af..4f87f403 100644 --- a/fhevm-engine/fhevm-engine-common/src/lib.rs +++ b/fhevm-engine/fhevm-engine-common/src/lib.rs @@ -1,14 +1,17 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right -} - -#[cfg(test)] -mod tests { - use super::*; +pub mod types; +pub mod tfhe_ops; - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } +pub struct FhevmKeys { + pub server_key: Vec, + pub client_key: Vec, + pub compact_public_key: Vec, } + +pub fn generate_fhe_keys() -> FhevmKeys { + let (client_key, server_key) = tfhe::generate_keys(tfhe::ConfigBuilder::default().build()); + let compact_key = tfhe::CompactPublicKey::new(&client_key); + let client_key = bincode::serialize(&client_key).unwrap(); + let server_key = bincode::serialize(&server_key).unwrap(); + let compact_public_key = bincode::serialize(&compact_key).unwrap(); + FhevmKeys { server_key, client_key, compact_public_key } +} \ No newline at end of file diff --git a/fhevm-engine/coprocessor/src/tfhe_ops.rs b/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs similarity index 93% rename from fhevm-engine/coprocessor/src/tfhe_ops.rs rename to fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs index 96c4b92b..ced6ca93 100644 --- a/fhevm-engine/coprocessor/src/tfhe_ops.rs +++ b/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs @@ -1,3 +1,4 @@ +use crate::types::{FheOperationType, FhevmError, SupportedFheCiphertexts, SupportedFheOperations}; use tfhe::{ prelude::{ CastInto, FheEq, FheMax, FheMin, FheOrd, FheTryTrivialEncrypt, IfThenElse, RotateLeft, RotateRight @@ -5,19 +6,319 @@ use tfhe::{ FheBool, FheUint16, FheUint32, FheUint64, FheUint8, }; -use crate::{ - server::coprocessor::{async_computation_input::Input, AsyncComputationInput}, - types::{CoprocessorError, FheOperationType, SupportedFheCiphertexts, SupportedFheOperations}, -}; + +pub fn deserialize_fhe_ciphertext( + input_type: i16, + input_bytes: &[u8], +) -> Result { + match input_type { + 1 => { + let v: tfhe::FheBool = bincode::deserialize(input_bytes).map_err(|e| FhevmError::DeserializationError(e))?; + Ok(SupportedFheCiphertexts::FheBool(v)) + } + 2 => { + let v: tfhe::FheUint8 = bincode::deserialize(input_bytes).map_err(|e| FhevmError::DeserializationError(e))?; + Ok(SupportedFheCiphertexts::FheUint8(v)) + } + 3 => { + let v: tfhe::FheUint16 = bincode::deserialize(input_bytes).map_err(|e| FhevmError::DeserializationError(e))?; + Ok(SupportedFheCiphertexts::FheUint16(v)) + } + 4 => { + let v: tfhe::FheUint32 = bincode::deserialize(input_bytes).map_err(|e| FhevmError::DeserializationError(e))?; + Ok(SupportedFheCiphertexts::FheUint32(v)) + } + 5 => { + let v: tfhe::FheUint64 = bincode::deserialize(input_bytes).map_err(|e| FhevmError::DeserializationError(e))?; + Ok(SupportedFheCiphertexts::FheUint64(v)) + } + _ => { + return Err(FhevmError::UnknownCiphertextType( + input_type, + )); + } + } +} + +/// Function assumes encryption key already set +pub fn debug_trivial_encrypt_be_bytes( + output_type: i16, + input_bytes: &[u8], +) -> SupportedFheCiphertexts { + match output_type { + 1 => SupportedFheCiphertexts::FheBool( + FheBool::try_encrypt_trivial(input_bytes[0] > 0).unwrap(), + ), + 2 => SupportedFheCiphertexts::FheUint8( + FheUint8::try_encrypt_trivial(input_bytes[0]).unwrap(), + ), + 3 => { + let mut padded: [u8; 2] = [0; 2]; + let padded_len = padded.len(); + let copy_from = padded_len - input_bytes.len(); + let len = padded.len().min(input_bytes.len()); + padded[copy_from..padded_len].copy_from_slice(&input_bytes[0..len]); + let res = u16::from_be_bytes(padded); + SupportedFheCiphertexts::FheUint16(FheUint16::try_encrypt_trivial(res).unwrap()) + } + 4 => { + let mut padded: [u8; 4] = [0; 4]; + let padded_len = padded.len(); + let copy_from = padded_len - input_bytes.len(); + let len = padded.len().min(input_bytes.len()); + padded[copy_from..padded_len].copy_from_slice(&input_bytes[0..len]); + let res: u32 = u32::from_be_bytes(padded); + SupportedFheCiphertexts::FheUint32(FheUint32::try_encrypt_trivial(res).unwrap()) + } + 5 => { + let mut padded: [u8; 8] = [0; 8]; + let padded_len = padded.len(); + let copy_from = padded_len - input_bytes.len(); + let len = padded.len().min(input_bytes.len()); + padded[copy_from..padded_len].copy_from_slice(&input_bytes[0..len]); + let res: u64 = u64::from_be_bytes(padded); + SupportedFheCiphertexts::FheUint64(FheUint64::try_encrypt_trivial(res).unwrap()) + } + other => { + panic!("Unknown input type for trivial encryption: {other}") + } + } +} pub fn current_ciphertext_version() -> i16 { 1 } +// return output ciphertext type +pub fn check_fhe_operand_types( + fhe_operation: i32, + input_types: &[i16], + input_handles: &[Vec], + is_input_handle_scalar: &[bool], +) -> Result { + assert_eq!(input_handles.len(), is_input_handle_scalar.len()); + + let fhe_op: SupportedFheOperations = fhe_operation.try_into()?; + + let scalar_operands = is_input_handle_scalar.iter().enumerate() + .filter(|(_, is_scalar)| **is_scalar) + .collect::>(); + + let is_scalar = scalar_operands.len() > 0; + + if scalar_operands.len() > 1 { + return Err(FhevmError::FheOperationOnlyOneOperandCanBeScalar { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + scalar_operand_count: scalar_operands.len(), + max_scalar_operands: 1, + }); + } + + if is_scalar { + assert_eq!( + scalar_operands.len(), 1, + "We checked already that not more than 1 scalar operand can be present" + ); + + if !does_fhe_operation_support_scalar(&fhe_op) { + return Err(FhevmError::FheOperationDoesntSupportScalar { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + scalar_requested: is_scalar, + scalar_supported: false, + }); + } + + let scalar_input_index =scalar_operands[0].0; + if scalar_input_index != 1 { + return Err(FhevmError::FheOperationOnlySecondOperandCanBeScalar { + scalar_input_index, + only_allowed_scalar_input_index: 1, + }); + } + } + + match fhe_op.op_type() { + FheOperationType::Binary => { + let expected_operands = 2; + if input_types.len() != expected_operands { + return Err(FhevmError::UnexpectedOperandCountForFheOperation { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + expected_operands, + got_operands: input_types.len(), + }); + } + + if !is_scalar && input_types[0] != input_types[1] { + return Err( + FhevmError::FheOperationDoesntHaveUniformTypesAsInput { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + operand_types: input_types.to_vec(), + }, + ); + } + + // special case for div operation, rhs for scalar must be zero + if is_scalar && fhe_op == SupportedFheOperations::FheDiv { + let all_zeroes = input_handles[1].iter().all(|i| *i == 0u8); + if all_zeroes { + return Err(FhevmError::FheOperationScalarDivisionByZero { + lhs_handle: format!("0x{}", hex::encode(&input_handles[0])), + rhs_value: format!("0x{}", hex::encode(&input_handles[1])), + fhe_operation, + fhe_operation_name: format!("{:?}", SupportedFheOperations::FheDiv), + }); + } + } + + if fhe_op.is_comparison() { + return Ok(1); // fhe bool type + } + + return Ok(input_types[0]); + } + FheOperationType::Unary => { + let expected_operands = 1; + if input_types.len() != expected_operands { + return Err(FhevmError::UnexpectedOperandCountForFheOperation { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + expected_operands, + got_operands: input_types.len(), + }); + } + + return Ok(input_types[0]); + } + FheOperationType::Other => { + match &fhe_op { + // two ops + uniform types branch + // what about scalar compute? + SupportedFheOperations::FheIfThenElse => { + let expected_operands = 3; + if input_types.len() != expected_operands { + return Err(FhevmError::UnexpectedOperandCountForFheOperation { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + expected_operands, + got_operands: input_types.len(), + }); + } + + // TODO: figure out typing system with constants + let fhe_bool_type = 1; + if input_types[0] != fhe_bool_type { + return Err(FhevmError::FheIfThenElseUnexpectedOperandTypes { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + first_expected_operand_type: fhe_bool_type, + first_expected_operand_type_name: "FheBool".to_string(), + first_operand_type: input_types[0], + }); + } + + if input_types[1] != input_types[2] { + return Err(FhevmError::FheIfThenElseMismatchingSecondAndThirdOperatorTypes { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + second_operand_type: input_types[1], + third_operand_type: input_types[2], + }); + } + + Ok(input_types[1]) + } + SupportedFheOperations::FheCast => { + let expected_operands = 2; + if input_types.len() != expected_operands { + return Err(FhevmError::UnexpectedOperandCountForFheOperation { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + expected_operands, + got_operands: input_types.len(), + }); + } + + match (is_input_handle_scalar[0], is_input_handle_scalar[1]) { + (false, true) => { + let op = &input_handles[1]; + if op.len() != 1 { + return Err(FhevmError::UnexpectedCastOperandSizeForScalarOperand { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + expected_scalar_operand_bytes: 1, + got_bytes: op.len(), + }); + } + + let output_type = op[0] as i16; + validate_fhe_type(output_type)?; + Ok(output_type) + } + (other_left, other_right) => { + let bool_to_op = |inp| { + (if inp { "scalar" } else { "handle" }).to_string() + }; + + return Err(FhevmError::UnexpectedCastOperandTypes { + fhe_operation, + fhe_operation_name: format!("{:?}", fhe_op), + expected_operator_combination: vec![ + "handle".to_string(), + "scalar".to_string(), + ], + got_operand_combination: vec![ + bool_to_op(other_left), + bool_to_op(other_right), + ], + }); + } + } + } + other => { + panic!("Unexpected branch: {:?}", other) + } + } + } + } +} + +pub fn validate_fhe_type(input_type: i16) -> Result<(), FhevmError> { + match input_type { + 1 | 2 | 3 | 4 | 5 => Ok(()), + _ => Err(FhevmError::UnknownCiphertextType(input_type)), + } +} + +pub fn does_fhe_operation_support_scalar(op: &SupportedFheOperations) -> bool { + match op.op_type() { + FheOperationType::Binary => true, + FheOperationType::Unary => false, + FheOperationType::Other => { + match op { + // second operand determines which type to cast to + SupportedFheOperations::FheCast => true, + _ => false, + } + } + } +} + +// add operations here that don't support both encrypted operands +pub fn does_fhe_operation_support_both_encrypted_operands(op: &SupportedFheOperations) -> bool { + match op { + SupportedFheOperations::FheDiv => false, + _ => true, + } +} + pub fn perform_fhe_operation( fhe_operation: i16, input_operands: &[SupportedFheCiphertexts], -) -> Result> { +) -> Result { let fhe_operation: SupportedFheOperations = fhe_operation.try_into()?; match fhe_operation { SupportedFheOperations::FheAdd => { @@ -1131,287 +1432,4 @@ pub fn perform_fhe_operation( } }, } -} - -/// Function assumes encryption key already set -pub fn debug_trivial_encrypt_be_bytes( - output_type: i16, - input_bytes: &[u8], -) -> SupportedFheCiphertexts { - match output_type { - 1 => SupportedFheCiphertexts::FheBool( - FheBool::try_encrypt_trivial(input_bytes[0] > 0).unwrap(), - ), - 2 => SupportedFheCiphertexts::FheUint8( - FheUint8::try_encrypt_trivial(input_bytes[0]).unwrap(), - ), - 3 => { - let mut padded: [u8; 2] = [0; 2]; - let padded_len = padded.len(); - let copy_from = padded_len - input_bytes.len(); - let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[0..len]); - let res = u16::from_be_bytes(padded); - SupportedFheCiphertexts::FheUint16(FheUint16::try_encrypt_trivial(res).unwrap()) - } - 4 => { - let mut padded: [u8; 4] = [0; 4]; - let padded_len = padded.len(); - let copy_from = padded_len - input_bytes.len(); - let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[0..len]); - let res: u32 = u32::from_be_bytes(padded); - SupportedFheCiphertexts::FheUint32(FheUint32::try_encrypt_trivial(res).unwrap()) - } - 5 => { - let mut padded: [u8; 8] = [0; 8]; - let padded_len = padded.len(); - let copy_from = padded_len - input_bytes.len(); - let len = padded.len().min(input_bytes.len()); - padded[copy_from..padded_len].copy_from_slice(&input_bytes[0..len]); - let res: u64 = u64::from_be_bytes(padded); - SupportedFheCiphertexts::FheUint64(FheUint64::try_encrypt_trivial(res).unwrap()) - } - other => { - panic!("Unknown input type for trivial encryption: {other}") - } - } -} - -pub fn validate_fhe_type(input_type: i16) -> Result<(), CoprocessorError> { - match input_type { - 1 | 2 | 3 | 4 | 5 => Ok(()), - _ => Err(CoprocessorError::UnknownCiphertextType(input_type)), - } -} - -pub fn deserialize_fhe_ciphertext( - input_type: i16, - input_bytes: &[u8], -) -> Result> { - match input_type { - 1 => { - let v: tfhe::FheBool = bincode::deserialize(input_bytes)?; - Ok(SupportedFheCiphertexts::FheBool(v)) - } - 2 => { - let v: tfhe::FheUint8 = bincode::deserialize(input_bytes)?; - Ok(SupportedFheCiphertexts::FheUint8(v)) - } - 3 => { - let v: tfhe::FheUint16 = bincode::deserialize(input_bytes)?; - Ok(SupportedFheCiphertexts::FheUint16(v)) - } - 4 => { - let v: tfhe::FheUint32 = bincode::deserialize(input_bytes)?; - Ok(SupportedFheCiphertexts::FheUint32(v)) - } - 5 => { - let v: tfhe::FheUint64 = bincode::deserialize(input_bytes)?; - Ok(SupportedFheCiphertexts::FheUint64(v)) - } - _ => { - return Err(Box::new(CoprocessorError::UnknownCiphertextType( - input_type, - ))); - } - } -} - -fn encode_comp_input_to_handle(input: &AsyncComputationInput) -> String { - match &input.input { - Some(Input::Scalar(sc)) => { - format!("0x{}", hex::encode(sc)) - } - Some(Input::InputHandle(handle)) => { - format!("0x{}", hex::encode(handle)) - } - None => panic!("we assume we get something here"), - } -} - -// return output ciphertext type -pub fn check_fhe_operand_types( - fhe_operation: i32, - input_types: &[i16], - is_scalar: bool, - input_handles: &[AsyncComputationInput], -) -> Result { - let fhe_op: SupportedFheOperations = fhe_operation.try_into()?; - - if is_scalar && !does_fhe_operation_support_scalar(&fhe_op) { - return Err(CoprocessorError::FheOperationDoesntSupportScalar { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - scalar_requested: is_scalar, - scalar_supported: false, - }); - } - - match fhe_op.op_type() { - FheOperationType::Binary => { - let expected_operands = 2; - if input_types.len() != expected_operands { - return Err(CoprocessorError::UnexpectedOperandCountForFheOperation { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - expected_operands, - got_operands: input_types.len(), - }); - } - - if !is_scalar && input_types[0] != input_types[1] { - return Err( - CoprocessorError::FheOperationDoesntHaveUniformTypesAsInput { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - operand_types: input_types.to_vec(), - }, - ); - } - - // special case for div operation, rhs for scalar must be zero - if is_scalar && fhe_op == SupportedFheOperations::FheDiv { - if let Some(Input::Scalar(sc)) = &input_handles[1].input { - let all_zeroes = sc.iter().all(|i| *i == 0u8); - if all_zeroes { - return Err(CoprocessorError::FheOperationScalarDivisionByZero { - lhs_handle: encode_comp_input_to_handle(&input_handles[0]), - rhs_value: encode_comp_input_to_handle(&input_handles[1]), - fhe_operation, - fhe_operation_name: format!("{:?}", SupportedFheOperations::FheDiv), - }); - } - } else { - panic!("rhs operand must be scalar here") - } - } - - if fhe_op.is_comparison() { - return Ok(1); // fhe bool type - } - - return Ok(input_types[0]); - } - FheOperationType::Unary => { - let expected_operands = 1; - if input_types.len() != expected_operands { - return Err(CoprocessorError::UnexpectedOperandCountForFheOperation { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - expected_operands, - got_operands: input_types.len(), - }); - } - - return Ok(input_types[0]); - } - FheOperationType::Other => { - match &fhe_op { - // two ops + uniform types branch - // what about scalar compute? - SupportedFheOperations::FheIfThenElse => { - let expected_operands = 3; - if input_types.len() != expected_operands { - return Err(CoprocessorError::UnexpectedOperandCountForFheOperation { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - expected_operands, - got_operands: input_types.len(), - }); - } - - // TODO: figure out typing system with constants - let fhe_bool_type = 1; - if input_types[0] != fhe_bool_type { - return Err(CoprocessorError::FheIfThenElseUnexpectedOperandTypes { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - first_expected_operand_type: fhe_bool_type, - first_expected_operand_type_name: "FheBool".to_string(), - first_operand_type: input_types[0], - }); - } - - if input_types[1] != input_types[2] { - return Err(CoprocessorError::FheIfThenElseMismatchingSecondAndThirdOperatorTypes { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - second_operand_type: input_types[1], - third_operand_type: input_types[2], - }); - } - - Ok(input_types[1]) - } - SupportedFheOperations::FheCast => { - let expected_operands = 2; - if input_types.len() != expected_operands { - return Err(CoprocessorError::UnexpectedOperandCountForFheOperation { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - expected_operands, - got_operands: input_types.len(), - }); - } - - match (&input_handles[0].input, &input_handles[1].input) { - (Some(a), Some(b)) => match (a, b) { - (Input::InputHandle(_ih), Input::Scalar(op)) => { - if op.len() != 1 { - return Err(CoprocessorError::UnexpectedCastOperandSizeForScalarOperand { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - expected_scalar_operand_bytes: 1, - got_bytes: op.len(), - }); - } - - let output_type = op[0] as i16; - validate_fhe_type(output_type)?; - Ok(output_type) - } - _ => { - return Err(CoprocessorError::UnexpectedCastOperandTypes { - fhe_operation, - fhe_operation_name: format!("{:?}", fhe_op), - expected_operator_combination: vec![ - "handle".to_string(), - "scalar".to_string(), - ], - }); - } - }, - _ => panic!("operands should always be some here, we checked earlier"), - } - } - other => { - panic!("Unexpected branch: {:?}", other) - } - } - } - } -} - -pub fn does_fhe_operation_support_scalar(op: &SupportedFheOperations) -> bool { - match op.op_type() { - FheOperationType::Binary => true, - FheOperationType::Unary => false, - FheOperationType::Other => { - match op { - // second operand determines which type to cast to - SupportedFheOperations::FheCast => true, - _ => false, - } - } - } -} - -// add operations here that don't support both encrypted operands -#[cfg(test)] -pub fn does_fhe_operation_support_both_encrypted_operands(op: &SupportedFheOperations) -> bool { - match op { - SupportedFheOperations::FheDiv => false, - _ => true, - } -} +} \ No newline at end of file diff --git a/fhevm-engine/fhevm-engine-common/src/types.rs b/fhevm-engine/fhevm-engine-common/src/types.rs new file mode 100644 index 00000000..1fcd0943 --- /dev/null +++ b/fhevm-engine/fhevm-engine-common/src/types.rs @@ -0,0 +1,353 @@ +use tfhe::integer::U256; +use tfhe::prelude::FheDecrypt; + +#[derive(Debug)] +pub enum FhevmError { + UnknownFheOperation(i32), + UnknownFheType(i32), + UnknownCiphertextType(i16), + DeserializationError(Box), + FheOperationOnlyOneOperandCanBeScalar { + fhe_operation: i32, + fhe_operation_name: String, + scalar_operand_count: usize, + max_scalar_operands: usize, + }, + FheOperationDoesntSupportScalar { + fhe_operation: i32, + fhe_operation_name: String, + scalar_requested: bool, + scalar_supported: bool, + }, + FheOperationOnlySecondOperandCanBeScalar { + scalar_input_index: usize, + only_allowed_scalar_input_index: usize, + }, + FheOperationDoesntHaveUniformTypesAsInput { + fhe_operation: i32, + fhe_operation_name: String, + operand_types: Vec, + }, + FheOperationScalarDivisionByZero { + lhs_handle: String, + rhs_value: String, + fhe_operation: i32, + fhe_operation_name: String, + }, + UnexpectedOperandCountForFheOperation { + fhe_operation: i32, + fhe_operation_name: String, + expected_operands: usize, + got_operands: usize, + }, + FheIfThenElseUnexpectedOperandTypes { + fhe_operation: i32, + fhe_operation_name: String, + first_operand_type: i16, + first_expected_operand_type: i16, + first_expected_operand_type_name: String, + }, + FheIfThenElseMismatchingSecondAndThirdOperatorTypes { + fhe_operation: i32, + fhe_operation_name: String, + second_operand_type: i16, + third_operand_type: i16, + }, + UnexpectedCastOperandTypes { + fhe_operation: i32, + fhe_operation_name: String, + expected_operator_combination: Vec, + got_operand_combination: Vec, + }, + UnexpectedCastOperandSizeForScalarOperand { + fhe_operation: i32, + fhe_operation_name: String, + expected_scalar_operand_bytes: usize, + got_bytes: usize, + }, +} + +impl std::error::Error for FhevmError {} + +impl std::fmt::Display for FhevmError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::UnknownFheOperation(op) => { + write!(f, "Unknown fhe operation: {}", op) + } + Self::UnknownFheType(op) => { + write!(f, "Unknown fhe type: {}", op) + } + Self::UnknownCiphertextType(the_type) => { + write!(f, "Unknown input ciphertext type: {}", the_type) + } + Self::DeserializationError(e) => { + write!(f, "error deserializing ciphertext: {:?}", e) + }, + Self::FheOperationDoesntSupportScalar { + fhe_operation, + fhe_operation_name, + .. + } => { + write!(f, "fhe operation number {fhe_operation} ({fhe_operation_name}) doesn't support scalar computation") + } + Self::FheOperationDoesntHaveUniformTypesAsInput { + fhe_operation, + fhe_operation_name, + operand_types, + } => { + write!(f, "fhe operation number {fhe_operation} ({fhe_operation_name}) expects uniform types as input, received: {:?}", operand_types) + } + Self::FheOperationScalarDivisionByZero { + lhs_handle, + rhs_value, + fhe_operation, + fhe_operation_name, + } => { + write!(f, "zero on the right side of scalar division, lhs handle: {lhs_handle}, rhs value: {rhs_value}, fhe operation: {fhe_operation} fhe operation name:{fhe_operation_name}") + } + Self::UnexpectedOperandCountForFheOperation { + fhe_operation, + fhe_operation_name, + expected_operands, + got_operands, + } => { + write!(f, "fhe operation number {fhe_operation} ({fhe_operation_name}) received unexpected operand count, expected: {expected_operands}, received: {got_operands}") + } + Self::FheOperationOnlySecondOperandCanBeScalar { + scalar_input_index, + only_allowed_scalar_input_index, + } => { + write!(f, "computation has scalar operand which is not the second operand, scalar input index: {scalar_input_index}, only allowed scalar input index: {only_allowed_scalar_input_index}") + } + Self::UnexpectedCastOperandTypes { + fhe_operation, + fhe_operation_name, + expected_operator_combination, + got_operand_combination, + } => { + write!(f, "unexpected operand types for cast, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, expected operand combination: {:?}, got operand combination: {:?}", expected_operator_combination, got_operand_combination) + } + Self::UnexpectedCastOperandSizeForScalarOperand { + fhe_operation, + fhe_operation_name, + expected_scalar_operand_bytes, + got_bytes, + } => { + write!(f, "unexpected operand size for cast, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, expected bytes: {}, got bytes: {}", expected_scalar_operand_bytes, got_bytes) + } + Self::FheIfThenElseUnexpectedOperandTypes { fhe_operation, fhe_operation_name, first_operand_type, first_expected_operand_type, .. } => { + write!(f, "fhe if then else first operand should always be FheBool, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, first operand type: {first_operand_type}, first operand expected type: {first_expected_operand_type}") + } + Self::FheIfThenElseMismatchingSecondAndThirdOperatorTypes { fhe_operation, fhe_operation_name, second_operand_type, third_operand_type } => { + write!(f, "fhe if then else second and third operand types don't match, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, second operand type: {second_operand_type}, third operand type: {third_operand_type}") + } + Self::FheOperationOnlyOneOperandCanBeScalar { fhe_operation, fhe_operation_name, scalar_operand_count, max_scalar_operands } => { + write!(f, "only one operand can be scalar, fhe operation: {fhe_operation}, fhe operation name: {fhe_operation_name}, second operand count: {scalar_operand_count}, max scalar operands: {max_scalar_operands}") + } + } + } +} + +pub enum SupportedFheCiphertexts { + FheBool(tfhe::FheBool), + FheUint8(tfhe::FheUint8), + FheUint16(tfhe::FheUint16), + FheUint32(tfhe::FheUint32), + FheUint64(tfhe::FheUint64), + Scalar(U256), +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, strum::EnumIter)] +#[repr(i8)] +pub enum SupportedFheOperations { + FheAdd = 0, + FheSub = 1, + FheMul = 2, + FheDiv = 3, + FheRem = 4, + FheBitAnd = 5, + FheBitOr = 6, + FheBitXor = 7, + FheShl = 8, + FheShr = 9, + FheRotl = 10, + FheRotr = 11, + FheEq = 12, + FheNe = 13, + FheGe = 14, + FheGt = 15, + FheLe = 16, + FheLt = 17, + FheMin = 18, + FheMax = 19, + FheNeg = 20, + FheNot = 21, + FheCast = 30, + FheIfThenElse = 31, +} + +#[derive(PartialEq, Eq)] +pub enum FheOperationType { + Binary, + Unary, + Other, +} + +impl SupportedFheCiphertexts { + pub fn serialize(&self) -> (i16, Vec) { + let type_num = self.type_num(); + match self { + SupportedFheCiphertexts::FheBool(v) => (type_num, bincode::serialize(v).unwrap()), + SupportedFheCiphertexts::FheUint8(v) => (type_num, bincode::serialize(v).unwrap()), + SupportedFheCiphertexts::FheUint16(v) => (type_num, bincode::serialize(v).unwrap()), + SupportedFheCiphertexts::FheUint32(v) => (type_num, bincode::serialize(v).unwrap()), + SupportedFheCiphertexts::FheUint64(v) => (type_num, bincode::serialize(v).unwrap()), + SupportedFheCiphertexts::Scalar(_) => { + panic!("we should never need to serialize scalar") + } + } + } + + pub fn type_num(&self) -> i16 { + match self { + SupportedFheCiphertexts::FheBool(_) => 1, + SupportedFheCiphertexts::FheUint8(_) => 2, + SupportedFheCiphertexts::FheUint16(_) => 3, + SupportedFheCiphertexts::FheUint32(_) => 4, + SupportedFheCiphertexts::FheUint64(_) => 5, + SupportedFheCiphertexts::Scalar(_) => { + panic!("we should never need to serialize scalar") + } + } + } + + pub fn decrypt(&self, client_key: &tfhe::ClientKey) -> String { + match self { + SupportedFheCiphertexts::FheBool(v) => v.decrypt(client_key).to_string(), + SupportedFheCiphertexts::FheUint8(v) => { + FheDecrypt::::decrypt(v, client_key).to_string() + } + SupportedFheCiphertexts::FheUint16(v) => { + FheDecrypt::::decrypt(v, client_key).to_string() + } + SupportedFheCiphertexts::FheUint32(v) => { + FheDecrypt::::decrypt(v, client_key).to_string() + } + SupportedFheCiphertexts::FheUint64(v) => { + FheDecrypt::::decrypt(v, client_key).to_string() + } + SupportedFheCiphertexts::Scalar(v) => { + let (l, h) = v.to_low_high_u128(); + format!("{l}{h}") + } + } + } +} + +impl SupportedFheOperations { + pub fn op_type(&self) -> FheOperationType { + match self { + SupportedFheOperations::FheAdd + | SupportedFheOperations::FheSub + | SupportedFheOperations::FheMul + | SupportedFheOperations::FheDiv + | SupportedFheOperations::FheRem + | SupportedFheOperations::FheBitAnd + | SupportedFheOperations::FheBitOr + | SupportedFheOperations::FheBitXor + | SupportedFheOperations::FheShl + | SupportedFheOperations::FheShr + | SupportedFheOperations::FheRotl + | SupportedFheOperations::FheRotr + | SupportedFheOperations::FheEq + | SupportedFheOperations::FheNe + | SupportedFheOperations::FheGe + | SupportedFheOperations::FheGt + | SupportedFheOperations::FheLe + | SupportedFheOperations::FheLt + | SupportedFheOperations::FheMin + | SupportedFheOperations::FheMax => FheOperationType::Binary, + SupportedFheOperations::FheNot | SupportedFheOperations::FheNeg => { + FheOperationType::Unary + } + SupportedFheOperations::FheIfThenElse | SupportedFheOperations::FheCast => { + FheOperationType::Other + } + } + } + + pub fn is_comparison(&self) -> bool { + match self { + SupportedFheOperations::FheEq + | SupportedFheOperations::FheNe + | SupportedFheOperations::FheGe + | SupportedFheOperations::FheGt + | SupportedFheOperations::FheLe + | SupportedFheOperations::FheLt => true, + _ => false, + } + } +} + +impl TryFrom for SupportedFheOperations { + type Error = FhevmError; + + fn try_from(value: i16) -> Result { + let res = match value { + 0 => Ok(SupportedFheOperations::FheAdd), + 1 => Ok(SupportedFheOperations::FheSub), + 2 => Ok(SupportedFheOperations::FheMul), + 3 => Ok(SupportedFheOperations::FheDiv), + 4 => Ok(SupportedFheOperations::FheRem), + 5 => Ok(SupportedFheOperations::FheBitAnd), + 6 => Ok(SupportedFheOperations::FheBitOr), + 7 => Ok(SupportedFheOperations::FheBitXor), + 8 => Ok(SupportedFheOperations::FheShl), + 9 => Ok(SupportedFheOperations::FheShr), + 10 => Ok(SupportedFheOperations::FheRotl), + 11 => Ok(SupportedFheOperations::FheRotr), + 12 => Ok(SupportedFheOperations::FheEq), + 13 => Ok(SupportedFheOperations::FheNe), + 14 => Ok(SupportedFheOperations::FheGe), + 15 => Ok(SupportedFheOperations::FheGt), + 16 => Ok(SupportedFheOperations::FheLe), + 17 => Ok(SupportedFheOperations::FheLt), + 18 => Ok(SupportedFheOperations::FheMin), + 19 => Ok(SupportedFheOperations::FheMax), + 20 => Ok(SupportedFheOperations::FheNeg), + 21 => Ok(SupportedFheOperations::FheNot), + 30 => Ok(SupportedFheOperations::FheCast), + 31 => Ok(SupportedFheOperations::FheIfThenElse), + _ => Err(FhevmError::UnknownFheOperation(value as i32)), + }; + + // ensure we're always having the same value serialized back and forth + if let Ok(v) = &res { + assert_eq!(v.clone() as i16, value); + } + + res + } +} + +// we get i32 from protobuf (smaller types unsupported) +// but in database we store i16 +impl TryFrom for SupportedFheOperations { + type Error = FhevmError; + + fn try_from(value: i32) -> Result { + let initial_value: i16 = value + .try_into() + .map_err(|_| FhevmError::UnknownFheOperation(value))?; + + let final_value: Result = initial_value.try_into(); + final_value + } +} + +impl From for i16 { + fn from(value: SupportedFheOperations) -> Self { + value as i16 + } +} \ No newline at end of file