Skip to content

Commit

Permalink
Add support for 128bit/160bit and 256bit operations
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zk committed Sep 4, 2024
1 parent 658db39 commit 53525e6
Show file tree
Hide file tree
Showing 8 changed files with 882 additions and 263 deletions.
1 change: 1 addition & 0 deletions fhevm-engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions fhevm-engine/coprocessor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::num::NonZeroUsize;

use crate::db_queries::{check_if_api_key_is_valid, check_if_ciphertexts_exist_in_db, fetch_tenant_server_key};
use crate::server::coprocessor::GenericResponse;
use bigdecimal::num_bigint::BigUint;
use fhevm_engine_common::tfhe_ops::{check_fhe_operand_types, current_ciphertext_version, debug_trivial_encrypt_be_bytes, deserialize_fhe_ciphertext, try_expand_ciphertext_list};
use fhevm_engine_common::types::{FhevmError, SupportedFheCiphertexts};
use sha3::{Digest, Keccak256};
Expand Down Expand Up @@ -346,7 +347,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ

let mut public_key = sqlx::query!(
"
SELECT sks_key
SELECT sks_key, cks_key
FROM tenants
WHERE tenant_id = $1
",
Expand All @@ -360,6 +361,9 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ

let public_key = public_key.pop().unwrap();

// for checking if decryption equals to trivial encryption value
let client_key: tfhe::ClientKey = bincode::deserialize(&public_key.cks_key.unwrap()).unwrap();

let cloned = req.values.clone();
let out_cts = tokio::task::spawn_blocking(move || {
let server_key: tfhe::ServerKey = bincode::deserialize(&public_key.sks_key).unwrap();
Expand All @@ -368,7 +372,10 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
// single threaded implementation as this is debug function and it is simple to implement
let mut res: Vec<(Vec<u8>, i16, Vec<u8>)> = Vec::with_capacity(cloned.len());
for v in cloned {
let ct = debug_trivial_encrypt_be_bytes(v.output_type as i16, &v.le_value);
let the_num = BigUint::from_bytes_be(&v.be_value).to_string();
let ct = debug_trivial_encrypt_be_bytes(v.output_type as i16, &v.be_value);
let decr = ct.decrypt(&client_key);
assert_eq!(the_num, decr, "Trivial encryption must preserve the original value");
let (ct_type, ct_bytes) = ct.serialize();
res.push((v.handle, ct_type, ct_bytes));
}
Expand Down
6 changes: 3 additions & 3 deletions fhevm-engine/coprocessor/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ async fn test_smoke() -> Result<(), Box<dyn std::error::Error>> {
values: vec![
DebugEncryptRequestSingle {
handle: vec![0x0a, 0xbc],
le_value: vec![123],
be_value: vec![123],
output_type: ct_type,
},
DebugEncryptRequestSingle {
handle: vec![0x0a, 0xbd],
le_value: vec![124],
be_value: vec![124],
output_type: ct_type,
},
],
Expand Down Expand Up @@ -110,4 +110,4 @@ async fn test_smoke() -> Result<(), Box<dyn std::error::Error>> {
}

Ok(())
}
}
106 changes: 94 additions & 12 deletions fhevm-engine/coprocessor/src/tests/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,15 @@ struct UnaryOperatorTestCase {
}

fn supported_bits() -> &'static [i32] {
&[8, 16, 32, 64]
&[
8,
16,
32,
64,
128,
160,
256,
]
}

fn supported_types() -> &'static [i32] {
Expand All @@ -47,9 +55,9 @@ fn supported_types() -> &'static [i32] {
3, // 16 bit
4, // 32 bit
5, // 64 bit
// 6, TODO: add 128 bit support
// 7, TODO: add 160 bit support
// 8, TODO: add 256 bit support
6, // 128 bit
7, // 160 bit
8, // 256 bit
]
}

Expand All @@ -59,6 +67,9 @@ fn supported_bits_to_bit_type_in_db(inp: i32) -> i32 {
16 => 3,
32 => 4,
64 => 5,
128 => 6,
160 => 7,
256 => 8,
other => panic!("unknown supported bits: {other}"),
}
}
Expand Down Expand Up @@ -104,14 +115,14 @@ async fn test_fhe_binary_operands() -> Result<(), Box<dyn std::error::Error>> {
);
enc_request_payload.push(DebugEncryptRequestSingle {
handle: lhs_handle.clone(),
le_value: lhs_bytes,
be_value: lhs_bytes,
output_type: op.input_types,
});
if !op.is_scalar {
let (_, rhs_bytes) = op.rhs.to_bytes_be();
enc_request_payload.push(DebugEncryptRequestSingle {
handle: rhs_handle.clone(),
le_value: rhs_bytes,
be_value: rhs_bytes,
output_type: op.input_types,
});
}
Expand Down Expand Up @@ -234,7 +245,7 @@ async fn test_fhe_unary_operands() -> Result<(), Box<dyn std::error::Error>> {
);
enc_request_payload.push(DebugEncryptRequestSingle {
handle: input_handle.clone(),
le_value: inp_bytes,
be_value: inp_bytes,
output_type: op.operand_types,
});

Expand Down Expand Up @@ -359,7 +370,7 @@ async fn test_fhe_casts() -> Result<(), Box<dyn std::error::Error>> {
);
enc_request_payload.push(DebugEncryptRequestSingle {
handle: input_handle.clone(),
le_value: inp_bytes,
be_value: inp_bytes,
output_type: *type_from,
});
cast_outputs.push(CastOutput {
Expand Down Expand Up @@ -482,12 +493,12 @@ async fn test_fhe_if_then_else() -> Result<(), Box<dyn std::error::Error>> {
let true_handle = next_handle();
enc_request_payload.push(DebugEncryptRequestSingle {
handle: false_handle.clone(),
le_value: BigInt::from(0).to_bytes_be().1,
be_value: BigInt::from(0).to_bytes_be().1,
output_type: fhe_bool_type,
});
enc_request_payload.push(DebugEncryptRequestSingle {
handle: true_handle.clone(),
le_value: BigInt::from(1).to_bytes_be().1,
be_value: BigInt::from(1).to_bytes_be().1,
output_type: fhe_bool_type,
});

Expand All @@ -504,12 +515,12 @@ async fn test_fhe_if_then_else() -> Result<(), Box<dyn std::error::Error>> {
};
enc_request_payload.push(DebugEncryptRequestSingle {
handle: left_handle.clone(),
le_value: BigInt::from(left_input).to_bytes_be().1,
be_value: BigInt::from(left_input).to_bytes_be().1,
output_type: *input_types,
});
enc_request_payload.push(DebugEncryptRequestSingle {
handle: right_handle.clone(),
le_value: BigInt::from(right_input).to_bytes_be().1,
be_value: BigInt::from(right_input).to_bytes_be().1,
output_type: *input_types,
});

Expand Down Expand Up @@ -718,6 +729,17 @@ fn compute_expected_unary_output(inp: &BigInt, op: SupportedFheOperations, bits:
let inp: u64 = inp.try_into().unwrap();
BigInt::from(inp.not())
}
128 => {
let inp: u128 = inp.try_into().unwrap();
BigInt::from(inp.not())
}
160 | 256 => {
let (_, mut bytes) = inp.to_bytes_be();
for byte in bytes.iter_mut() {
*byte = byte.not();
}
BigInt::from(inp.not())
}
other => {
panic!("unknown bits: {other}")
}
Expand All @@ -740,6 +762,13 @@ fn compute_expected_unary_output(inp: &BigInt, op: SupportedFheOperations, bits:
let inp: i64 = inp.try_into().unwrap();
BigInt::from(-inp as u64)
}
128 => {
let inp: i128 = inp.try_into().unwrap();
BigInt::from(-inp as u128)
}
160 | 256 => {
inp * -1
}
other => {
panic!("unknown bits: {other}")
}
Expand All @@ -748,6 +777,41 @@ fn compute_expected_unary_output(inp: &BigInt, op: SupportedFheOperations, bits:
}
}

fn rotate_left_big_int(inp: &BigInt, rot_by: u32) -> BigInt {
let mut new_num = inp.clone();
let mut idx_vec = Vec::new();
for bit in 0..inp.bits() {
idx_vec.push(bit);
}
idx_vec.rotate_left(rot_by as usize);
for bit in 0..inp.bits() {
new_num.set_bit(idx_vec[bit as usize], inp.bit(bit));
}
new_num
}

fn rotate_right_big_int(inp: &BigInt, rot_by: u32) -> BigInt {
let mut new_num = inp.clone();
let mut idx_vec = Vec::new();
for bit in 0..inp.bits() {
idx_vec.push(bit);
}
idx_vec.rotate_right(rot_by as usize);
for bit in 0..inp.bits() {
new_num.set_bit(idx_vec[bit as usize], inp.bit(bit));
}
new_num
}

#[test]
fn big_int_rotation() {
let the_int = BigInt::from(22);
let left_rot: u8 = rotate_left_big_int(&the_int, 1).try_into().unwrap();
let right_rot: u8 = rotate_right_big_int(&the_int, 1).try_into().unwrap();
assert_eq!(left_rot, 13);
assert_eq!(right_rot, 11);
}

fn compute_expected_binary_output(
lhs: &BigInt,
rhs: &BigInt,
Expand Down Expand Up @@ -794,6 +858,15 @@ fn compute_expected_binary_output(
.unwrap()
.rotate_left(TryInto::<u32>::try_into(rhs).unwrap()),
),
128 => BigInt::from(
TryInto::<u128>::try_into(lhs)
.unwrap()
.rotate_left(TryInto::<u32>::try_into(rhs).unwrap()),
),
160 | 256 => {
let rot_by = TryInto::<u32>::try_into(rhs).unwrap();
rotate_left_big_int(lhs, rot_by)
}
other => {
panic!("unsupported bits for rotl: {other}")
}
Expand All @@ -819,6 +892,15 @@ fn compute_expected_binary_output(
.unwrap()
.rotate_right(TryInto::<u32>::try_into(rhs).unwrap()),
),
128 => BigInt::from(
TryInto::<u128>::try_into(lhs)
.unwrap()
.rotate_left(TryInto::<u32>::try_into(rhs).unwrap()),
),
160 | 256 => {
let rot_by = TryInto::<u32>::try_into(rhs).unwrap();
rotate_right_big_int(lhs, rot_by)
}
other => {
panic!("unsupported bits for rotr: {other}")
}
Expand Down
1 change: 1 addition & 0 deletions fhevm-engine/fhevm-engine-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ anyhow.workspace = true
strum = { version = "0.26", features = ["derive"] }
bincode = "1.3.3"
hex = "0.4"
bigdecimal = "0.4.5"

[[bin]]
name = "generate-keys"
Expand Down
Loading

0 comments on commit 53525e6

Please sign in to comment.