diff --git a/fhevm-engine/coprocessor/src/tests/operators.rs b/fhevm-engine/coprocessor/src/tests/operators.rs index 31d764c4..da69d4dd 100644 --- a/fhevm-engine/coprocessor/src/tests/operators.rs +++ b/fhevm-engine/coprocessor/src/tests/operators.rs @@ -39,13 +39,13 @@ struct UnaryOperatorTestCase { } fn supported_bits() -> &'static [i32] { - &[8, 16, 32, 64, 128, 160, 256, 512, 1024, 2048] + &[4, 8, 16, 32, 64, 128, 160, 256, 512, 1024, 2048] } pub fn supported_types() -> &'static [i32] { &[ - 0, // bool - // 1, TODO: add 4 bit support + 0, // bool + 1, // 4 bit 2, // 8 bit 3, // 16 bit 4, // 32 bit @@ -61,6 +61,7 @@ pub fn supported_types() -> &'static [i32] { fn supported_bits_to_bit_type_in_db(inp: i32) -> i32 { match inp { + 4 => 1, 8 => 2, 16 => 3, 32 => 4, @@ -620,12 +621,12 @@ fn generate_binary_test_cases() -> Vec { SupportedFheOperations::FheRotr, ]; let mut push_case = |bits: i32, is_scalar: bool, shift_by: i32, op: SupportedFheOperations| { - let mut lhs = BigInt::from(12); - let mut rhs = BigInt::from(7); + let mut lhs = BigInt::from(6); + let mut rhs = BigInt::from(2); lhs <<= shift_by; // don't shift by much for bit shift opts not to make result 0 if bit_shift_ops.contains(&op) { - rhs = BigInt::from(2); + rhs = BigInt::from(1); } else { rhs <<= shift_by; } @@ -651,7 +652,8 @@ fn generate_binary_test_cases() -> Vec { for bits in supported_bits() { let bits = *bits; - let mut shift_by = bits - 8; + let mut shift_by = + if bits > 4 { bits - 8 } else { 0 }; for op in SupportedFheOperations::iter() { if bits <= 256 || op.supports_ebytes_inputs() { if op == SupportedFheOperations::FheMul { @@ -679,12 +681,13 @@ fn generate_unary_test_cases() -> Vec { for bits in supported_bits() { let bits = *bits; - let shift_by = bits - 8; + let shift_by = bits - 3; + let max_bits_value = (BigInt::from(1) << bits) - 1; for op in SupportedFheOperations::iter() { if op.op_type() == FheOperationType::Unary { - let mut inp = BigInt::from(7); + let mut inp = BigInt::from(3); inp <<= shift_by; - let expected_output = compute_expected_unary_output(&inp, op); + let expected_output = compute_expected_unary_output(&inp, op) & &max_bits_value; let operand = op as i32; cases.push(UnaryOperatorTestCase { bits, diff --git a/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs b/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs index c45bd1b1..69b8e6a6 100644 --- a/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs +++ b/fhevm-engine/fhevm-engine-common/src/tfhe_ops.rs @@ -2,7 +2,7 @@ use crate::types::{is_ebytes_type, FheOperationType, FhevmError, SupportedFheCip use tfhe::{ integer::{bigint::StaticUnsignedBigInt, U256}, prelude::{ CastInto, FheEq, FheMax, FheMin, FheOrd, FheTryTrivialEncrypt, IfThenElse, RotateLeft, RotateRight - }, FheBool, FheUint1024, FheUint128, FheUint16, FheUint160, FheUint2048, FheUint256, FheUint32, FheUint512, FheUint64, FheUint8, Seed + }, FheBool, FheUint1024, FheUint128, FheUint16, FheUint160, FheUint2048, FheUint256, FheUint32, FheUint4, FheUint512, FheUint64, FheUint8, Seed }; pub fn deserialize_fhe_ciphertext( @@ -15,6 +15,11 @@ pub fn deserialize_fhe_ciphertext( .map_err(|e| FhevmError::DeserializationError(e))?; Ok(SupportedFheCiphertexts::FheBool(v)) } + 1 => { + let v: tfhe::FheUint4 = bincode::deserialize(input_bytes) + .map_err(|e| FhevmError::DeserializationError(e))?; + Ok(SupportedFheCiphertexts::FheUint4(v)) + } 2 => { let v: tfhe::FheUint8 = bincode::deserialize(input_bytes) .map_err(|e| FhevmError::DeserializationError(e))?; @@ -80,6 +85,9 @@ pub fn trivial_encrypt_be_bytes( 0 => SupportedFheCiphertexts::FheBool( FheBool::try_encrypt_trivial(input_bytes[0] > 0).unwrap(), ), + 1 => SupportedFheCiphertexts::FheUint4( + FheUint4::try_encrypt_trivial(input_bytes[0]).unwrap(), + ), 2 => SupportedFheCiphertexts::FheUint8( FheUint8::try_encrypt_trivial(input_bytes[0]).unwrap(), ), @@ -261,6 +269,14 @@ pub fn try_expand_ciphertext_list( res.push(SupportedFheCiphertexts::FheBool(ct)); } + tfhe::FheTypes::Uint4 => { + let ct: tfhe::FheUint4 = expanded + .get(idx) + .expect("Index must exist") + .expect("Must succeed, we just checked this is the type"); + + res.push(SupportedFheCiphertexts::FheUint4(ct)); + } tfhe::FheTypes::Uint8 => { let ct: tfhe::FheUint8 = expanded .get(idx) @@ -702,7 +718,7 @@ pub fn validate_fhe_type(input_type: i32) -> Result<(), FhevmError> { .try_into() .or(Err(FhevmError::UnknownFheType(input_type)))?; match i16_type { - 0 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 => Ok(()), + 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 => Ok(()), _ => Err(FhevmError::UnknownFheType(input_type)), } } @@ -753,6 +769,9 @@ pub fn perform_fhe_operation( // fhe add match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a + b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a + b)) } @@ -783,6 +802,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a + b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a + (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a + (l as u8))) @@ -819,6 +842,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a - b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a - b)) } @@ -849,6 +875,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a - b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a - (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a - (l as u8))) @@ -885,6 +915,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a * b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a * b)) } @@ -915,6 +948,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a * b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a * (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a * (l as u8))) @@ -951,6 +988,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a / b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a / b)) } @@ -981,6 +1021,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a / b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a / (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a / (l as u8))) @@ -1017,6 +1061,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a % b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a % b)) } @@ -1047,6 +1094,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a % b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a % (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a % (l as u8))) @@ -1083,6 +1134,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a & b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a & b)) } @@ -1113,6 +1167,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes64(a), SupportedFheCiphertexts::FheBytes64(b)) => { Ok(SupportedFheCiphertexts::FheBytes64(a & b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a & (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a & (l as u8))) @@ -1149,6 +1207,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a | b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a | b)) } @@ -1179,6 +1240,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a | b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a | (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a | (l as u8))) @@ -1215,6 +1280,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a ^ b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a ^ b)) } @@ -1245,6 +1313,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a ^ b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a ^ (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a ^ (l as u8))) @@ -1281,6 +1353,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a << b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a << b)) } @@ -1311,6 +1386,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a << b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a << (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a << (l as u8))) @@ -1347,6 +1426,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a >> b)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a >> b)) } @@ -1377,6 +1459,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a >> b)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a >> (l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a >> (l as u8))) @@ -1413,6 +1499,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a.rotate_left(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a.rotate_left(b))) } @@ -1443,6 +1532,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a.rotate_left(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a.rotate_left(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a.rotate_left(l as u8))) @@ -1479,6 +1572,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a.rotate_right(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a.rotate_right(b))) } @@ -1509,6 +1605,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a.rotate_right(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a.rotate_right(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a.rotate_right(l as u8))) @@ -1545,6 +1645,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a.min(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a.min(b))) } @@ -1575,6 +1678,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a.min(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a.min(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a.min(l as u8))) @@ -1611,6 +1718,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheUint4(a.max(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a.max(b))) } @@ -1641,6 +1751,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBytes256(a.max(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheUint4(a.max(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheUint8(a.max(l as u8))) @@ -1680,6 +1794,9 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBool(a), SupportedFheCiphertexts::FheBool(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.eq(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheBool(a.eq(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.eq(b))) } @@ -1715,6 +1832,10 @@ pub fn perform_fhe_operation( let non_zero = l > 0 || h > 0; Ok(SupportedFheCiphertexts::FheBool(a.eq(non_zero))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheBool(a.eq(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheBool(a.eq(l as u8))) @@ -1754,6 +1875,9 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBool(a), SupportedFheCiphertexts::FheBool(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.ne(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheBool(a.ne(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.ne(b))) } @@ -1789,6 +1913,10 @@ pub fn perform_fhe_operation( let non_zero = l > 0 || h > 0; Ok(SupportedFheCiphertexts::FheBool(a.ne(non_zero))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheBool(a.ne(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheBool(a.ne(l as u8))) @@ -1825,6 +1953,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheBool(a.ge(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.ge(b))) } @@ -1855,6 +1986,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.ge(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheBool(a.ge(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheBool(a.ge(l as u8))) @@ -1891,6 +2026,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheBool(a.gt(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.gt(b))) } @@ -1921,6 +2059,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.gt(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheBool(a.gt(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheBool(a.gt(l as u8))) @@ -1957,6 +2099,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheBool(a.le(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.le(b))) } @@ -1987,6 +2132,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.le(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheBool(a.le(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheBool(a.le(l as u8))) @@ -2023,6 +2172,9 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 2); match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + Ok(SupportedFheCiphertexts::FheBool(a.lt(b))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.lt(b))) } @@ -2053,6 +2205,10 @@ pub fn perform_fhe_operation( (SupportedFheCiphertexts::FheBytes256(a), SupportedFheCiphertexts::FheBytes256(b)) => { Ok(SupportedFheCiphertexts::FheBool(a.lt(b))) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::Scalar(b)) => { + let (l, _) = b.to_low_high_u128(); + Ok(SupportedFheCiphertexts::FheBool(a.lt(l as u8))) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { let (l, _) = b.to_low_high_u128(); Ok(SupportedFheCiphertexts::FheBool(a.lt(l as u8))) @@ -2090,6 +2246,7 @@ pub fn perform_fhe_operation( match &input_operands[0] { SupportedFheCiphertexts::FheBool(a) => Ok(SupportedFheCiphertexts::FheBool(!a)), + SupportedFheCiphertexts::FheUint4(a) => Ok(SupportedFheCiphertexts::FheUint4(!a)), SupportedFheCiphertexts::FheUint8(a) => Ok(SupportedFheCiphertexts::FheUint8(!a)), SupportedFheCiphertexts::FheUint16(a) => Ok(SupportedFheCiphertexts::FheUint16(!a)), SupportedFheCiphertexts::FheUint32(a) => Ok(SupportedFheCiphertexts::FheUint32(!a)), @@ -2109,6 +2266,7 @@ pub fn perform_fhe_operation( assert_eq!(input_operands.len(), 1); match &input_operands[0] { + SupportedFheCiphertexts::FheUint4(a) => Ok(SupportedFheCiphertexts::FheUint4(-a)), SupportedFheCiphertexts::FheUint8(a) => Ok(SupportedFheCiphertexts::FheUint8(-a)), SupportedFheCiphertexts::FheUint16(a) => Ok(SupportedFheCiphertexts::FheUint16(-a)), SupportedFheCiphertexts::FheUint32(a) => Ok(SupportedFheCiphertexts::FheUint32(-a)), @@ -2136,6 +2294,10 @@ pub fn perform_fhe_operation( let res = flag.select(a, b); Ok(SupportedFheCiphertexts::FheBool(res)) } + (SupportedFheCiphertexts::FheUint4(a), SupportedFheCiphertexts::FheUint4(b)) => { + let res = flag.select(a, b); + Ok(SupportedFheCiphertexts::FheUint4(res)) + } (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { let res = flag.select(a, b); Ok(SupportedFheCiphertexts::FheUint8(res)) @@ -2190,6 +2352,66 @@ pub fn perform_fhe_operation( return Ok(SupportedFheCiphertexts::FheBool(inp.clone())); } else { match l { + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } + 2 => { + let out: tfhe::FheUint8 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint8(out)) + } + 3 => { + let out: tfhe::FheUint16 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint16(out)) + } + 4 => { + let out: tfhe::FheUint32 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint32(out)) + } + 5 => { + let out: tfhe::FheUint64 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint64(out)) + } + 6 => { + let out: tfhe::FheUint128 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint128(out)) + } + 7 => { + let out: tfhe::FheUint160 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint160(out)) + } + 8 => { + let out: tfhe::FheUint256 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint256(out)) + } + 9 => { + let out: tfhe::FheUint512 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheBytes64(out)) + } + 10 => { + let out: tfhe::FheUint1024 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheBytes128(out)) + } + 11 => { + let out: tfhe::FheUint2048 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheBytes256(out)) + } + other => panic!("unexpected type: {other}"), + } + } + } + (SupportedFheCiphertexts::FheUint4(inp), SupportedFheCiphertexts::Scalar(op)) => { + let (l, _) = op.to_low_high_u128(); + let l = l as i16; + let type_id = input_operands[0].type_num(); + if l == type_id { + return Ok(SupportedFheCiphertexts::FheUint4(inp.clone())); + } else { + match l { + 0 => { + let out: tfhe::FheBool = inp.gt(0); + Ok(SupportedFheCiphertexts::FheBool(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2246,6 +2468,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 3 => { let out: tfhe::FheUint16 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint16(out)) @@ -2298,6 +2524,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2350,6 +2580,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2402,6 +2636,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2454,6 +2692,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2506,6 +2748,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2558,6 +2804,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2610,6 +2860,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2662,6 +2916,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2714,6 +2972,10 @@ pub fn perform_fhe_operation( let out: tfhe::FheBool = inp.gt(0); Ok(SupportedFheCiphertexts::FheBool(out)) } + 1 => { + let out: tfhe::FheUint4 = inp.clone().cast_into(); + Ok(SupportedFheCiphertexts::FheUint4(out)) + } 2 => { let out: tfhe::FheUint8 = inp.clone().cast_into(); Ok(SupportedFheCiphertexts::FheUint8(out)) @@ -2794,6 +3056,12 @@ pub fn generate_random_number(the_type: i16, seed: u128, upper_bound: Option { + let bit_count = 4; + let random_bits = upper_bound.map(|i| subtract_from - i.leading_zeros()) + .unwrap_or(bit_count).min(bit_count) as u64; + SupportedFheCiphertexts::FheUint4(FheUint4::generate_oblivious_pseudo_random(Seed(seed), random_bits)) + }, 2 => { let bit_count = 8; let random_bits = upper_bound.map(|i| subtract_from - i.leading_zeros()) diff --git a/fhevm-engine/fhevm-engine-common/src/types.rs b/fhevm-engine/fhevm-engine-common/src/types.rs index 28ed2b47..2f6637ca 100644 --- a/fhevm-engine/fhevm-engine-common/src/types.rs +++ b/fhevm-engine/fhevm-engine-common/src/types.rs @@ -258,6 +258,7 @@ impl std::fmt::Display for FhevmError { #[derive(Clone)] pub enum SupportedFheCiphertexts { FheBool(tfhe::FheBool), + FheUint4(tfhe::FheUint4), FheUint8(tfhe::FheUint8), FheUint16(tfhe::FheUint16), FheUint32(tfhe::FheUint32), @@ -315,6 +316,7 @@ impl SupportedFheCiphertexts { let type_num = self.type_num(); match self { SupportedFheCiphertexts::FheBool(v) => (type_num, bincode::serialize(v).unwrap()), + SupportedFheCiphertexts::FheUint4(v) => (type_num, bincode::serialize(v).unwrap()), SupportedFheCiphertexts::FheUint8(v) => (type_num, bincode::serialize(v).unwrap()), SupportedFheCiphertexts::FheUint16(v) => (type_num, bincode::serialize(v).unwrap()), SupportedFheCiphertexts::FheUint32(v) => (type_num, bincode::serialize(v).unwrap()), @@ -335,7 +337,7 @@ impl SupportedFheCiphertexts { match self { // values taken to match with solidity library SupportedFheCiphertexts::FheBool(_) => 0, - // TODO: add FheUint4 support + SupportedFheCiphertexts::FheUint4(_) => 1, SupportedFheCiphertexts::FheUint8(_) => 2, SupportedFheCiphertexts::FheUint16(_) => 3, SupportedFheCiphertexts::FheUint32(_) => 4, @@ -355,6 +357,9 @@ impl SupportedFheCiphertexts { pub fn decrypt(&self, client_key: &tfhe::ClientKey) -> String { match self { SupportedFheCiphertexts::FheBool(v) => v.decrypt(client_key).to_string(), + SupportedFheCiphertexts::FheUint4(v) => { + FheDecrypt::::decrypt(v, client_key).to_string() + } SupportedFheCiphertexts::FheUint8(v) => { FheDecrypt::::decrypt(v, client_key).to_string() } @@ -412,6 +417,7 @@ impl SupportedFheCiphertexts { let mut builder = CompressedCiphertextListBuilder::new(); match self { SupportedFheCiphertexts::FheBool(c) => builder.push(c), + SupportedFheCiphertexts::FheUint4(c) => builder.push(c), SupportedFheCiphertexts::FheUint8(c) => builder.push(c), SupportedFheCiphertexts::FheUint16(c) => builder.push(c), SupportedFheCiphertexts::FheUint32(c) => builder.push(c), @@ -434,7 +440,10 @@ impl SupportedFheCiphertexts { pub fn decompress(ct_type: i16, list: &[u8]) -> Result { let list: CompressedCiphertextList = bincode::deserialize(list)?; match ct_type { - 1 => Ok(SupportedFheCiphertexts::FheBool( + 0 => Ok(SupportedFheCiphertexts::FheBool( + list.get(0)?.ok_or(FhevmError::MissingTfheRsData)?, + )), + 1 => Ok(SupportedFheCiphertexts::FheUint4( list.get(0)?.ok_or(FhevmError::MissingTfheRsData)?, )), 2 => Ok(SupportedFheCiphertexts::FheUint8( @@ -477,6 +486,7 @@ impl SupportedFheCiphertexts { | SupportedFheCiphertexts::FheBytes128(_) | SupportedFheCiphertexts::FheBytes256(_) => true, SupportedFheCiphertexts::FheBool(_) + | SupportedFheCiphertexts::FheUint4(_) | SupportedFheCiphertexts::FheUint8(_) | SupportedFheCiphertexts::FheUint16(_) | SupportedFheCiphertexts::FheUint32(_)