Skip to content

Commit

Permalink
fix(cpu): fix corner case when estimating the num blocks required
Browse files Browse the repository at this point in the history
  • Loading branch information
guillermo-oyarzun committed Feb 26, 2025
1 parent ec3f3a1 commit 22377bb
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 62 deletions.
29 changes: 2 additions & 27 deletions tfhe/src/integer/gpu/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use crate::core_crypto::prelude::{
par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key, LweBootstrapKeyOwned,
LweMultiBitBootstrapKeyOwned,
};
use crate::integer::gpu::UnsignedInteger;
use crate::integer::ClientKey;
use crate::shortint::ciphertext::{MaxDegree, MaxNoiseLevel};
use crate::shortint::engine::ShortintEngine;
Expand Down Expand Up @@ -252,31 +251,7 @@ impl CudaServerKey {
}
}

#[allow(clippy::unused_self)]
pub(crate) fn num_bits_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
if clear == Clear::MAX {
Clear::BITS
} else {
let bits = (clear + Clear::ONE).ceil_ilog2() as usize;
if bits == 0 {
1
} else {
bits
}
}
}

/// Returns how many blocks a radix ciphertext should have to
/// be able to represent the given unsigned integer
pub(crate) fn num_blocks_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
let num_bits_to_represent_output_value = self.num_bits_to_represent_unsigned_value(clear);
let num_bits_in_message = self.message_modulus.0.ilog2();
num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize)
pub fn message_modulus(&self) -> MessageModulus {
self.message_modulus
}
}
7 changes: 4 additions & 3 deletions tfhe/src/integer/gpu/server_key/radix/vector_find.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextIn
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::radix::CudaRadixCiphertext;
use crate::integer::gpu::server_key::CudaServerKey;
use crate::integer::server_key::num_blocks_to_represent_unsigned_value;
pub use crate::integer::server_key::radix_parallel::MatchValues;
use crate::prelude::CastInto;
use itertools::Itertools;
Expand Down Expand Up @@ -147,7 +148,7 @@ impl CudaServerKey {
.1;

let num_blocks_to_represent_values =
self.num_blocks_to_represent_unsigned_value(max_output_value);
num_blocks_to_represent_unsigned_value(max_output_value, self.message_modulus);

let blocks_ct = self.convert_selectors_to_unsigned_radix_ciphertext(&selectors, streams);

Expand Down Expand Up @@ -277,7 +278,7 @@ impl CudaServerKey {
if matches.get_values().is_empty() {
let ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(
or_value,
self.num_blocks_to_represent_unsigned_value(or_value),
num_blocks_to_represent_unsigned_value(or_value, self.message_modulus),
streams,
);
return ct;
Expand All @@ -287,7 +288,7 @@ impl CudaServerKey {
// The result must have as many block to represent either the result of the match or the
// or_value
let num_blocks_to_represent_or_value =
self.num_blocks_to_represent_unsigned_value(or_value);
num_blocks_to_represent_unsigned_value(or_value, self.message_modulus);
let num_blocks = (result.as_ref().d_blocks.lwe_ciphertext_count().0)
.max(num_blocks_to_represent_or_value);
let or_value: CudaUnsignedRadixCiphertext =
Expand Down
52 changes: 30 additions & 22 deletions tfhe/src/integer/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,28 +211,6 @@ impl ServerKey {
self.key.carry_modulus
}

pub fn num_bits_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
if clear == Clear::MAX {
Clear::BITS
} else {
(clear + Clear::ONE).ceil_ilog2() as usize
}
}

/// Returns how many blocks a radix ciphertext should have to
/// be able to represent the given unsigned integer
pub fn num_blocks_to_represent_unsigned_value<Clear>(&self, clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
let num_bits_to_represent_output_value = self.num_bits_to_represent_unsigned_value(clear);
let num_bits_in_message = self.message_modulus().0.ilog2();
num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize)
}

/// Returns how many ciphertext can be summed at once
///
/// The number of ciphertext that can be added together depends on the degree
Expand All @@ -250,6 +228,36 @@ impl ServerKey {
}
}

pub fn num_bits_to_represent_unsigned_value<Clear>(clear: Clear) -> usize
where
Clear: UnsignedInteger,
{
if clear == Clear::MAX {
Clear::BITS
} else {
let bits = (clear + Clear::ONE).ceil_ilog2() as usize;
if bits == 0 {
1
} else {
bits
}
}
}

/// Returns how many blocks a radix ciphertext should have to
/// be able to represent the given unsigned integer
pub fn num_blocks_to_represent_unsigned_value<Clear>(
clear: Clear,
message_modulus: MessageModulus,
) -> usize
where
Clear: UnsignedInteger,
{
let num_bits_to_represent_output_value = num_bits_to_represent_unsigned_value(clear);
let num_bits_in_message = message_modulus.0.ilog2();
num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize)
}

impl AsRef<crate::shortint::ServerKey> for ServerKey {
fn as_ref(&self) -> &crate::shortint::ServerKey {
&self.key
Expand Down
11 changes: 7 additions & 4 deletions tfhe/src/integer/server_key/radix_parallel/count_zeros_ones.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use super::ServerKey;
use crate::integer::server_key::{
num_bits_to_represent_unsigned_value, num_blocks_to_represent_unsigned_value,
};
use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, SignedRadixCiphertext};
use crate::shortint::ciphertext::Degree;

Expand Down Expand Up @@ -192,7 +195,7 @@ impl ServerKey {
.checked_mul(ct.blocks().len() as u32)
.expect("Number of bits exceed u32::MAX");
let num_unsigned_blocks =
self.num_blocks_to_represent_unsigned_value(max_possible_bit_count);
num_blocks_to_represent_unsigned_value(max_possible_bit_count, self.message_modulus());
if count_kind == BitCountKind::One {
let things_to_sum = pre_count
.into_iter()
Expand Down Expand Up @@ -221,8 +224,7 @@ impl ServerKey {
// But in the case of 1_X parameters, counting ones does not require to have
// a LUT done on each block to count the number of ones, and to avoid having to do a
// LUT to count zeros we prefer to change a bit the sum
let num_bits_needed =
self.num_bits_to_represent_unsigned_value(max_possible_bit_count) + 1;
let num_bits_needed = num_bits_to_represent_unsigned_value(max_possible_bit_count) + 1;
let num_signed_blocks = num_bits_needed.div_ceil(num_bits_in_block as usize);
assert!(num_signed_blocks >= num_unsigned_blocks);

Expand Down Expand Up @@ -523,7 +525,8 @@ impl ServerKey {
let max_possible_bit_count = num_bits_in_block
.checked_mul(ct.blocks().len() as u32)
.expect("Number of bits exceed u32::MAX");
let num_blocks = self.num_blocks_to_represent_unsigned_value(max_possible_bit_count);
let num_blocks =
num_blocks_to_represent_unsigned_value(max_possible_bit_count, self.message_modulus());

let things_to_sum = pre_count
.into_iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{
nb_tests_for_params, random_non_zero_value, unsigned_modulus, CpuFunctionExecutor,
ExpectedDegrees, ExpectedNoiseLevels, MAX_VEC_LEN, NB_CTXT,
};
use crate::integer::server_key::MatchValues;
use crate::integer::server_key::{num_blocks_to_represent_unsigned_value, MatchValues};
use crate::integer::tests::create_parameterized_test;
use rand::prelude::*;

Expand Down Expand Up @@ -570,7 +570,8 @@ where
sks.create_trivial_radix(rng.gen_range(0..modulus), NB_CTXT),
];
let default_value = rng.gen_range(0..modulus);
let expected_len = sks.num_blocks_to_represent_unsigned_value(default_value);
let expected_len =
num_blocks_to_represent_unsigned_value(default_value, sks.message_modulus());

for ct in inputs {
let result = executor.execute((&ct, &empty_lut, default_value));
Expand Down Expand Up @@ -606,7 +607,7 @@ where

assert_eq!(
result.blocks.len(),
sks.num_blocks_to_represent_unsigned_value(u64::MAX)
num_blocks_to_represent_unsigned_value(u64::MAX, sks.message_modulus())
);

assert_eq!(cks.decrypt::<u64>(&result), u64::MAX);
Expand Down
7 changes: 4 additions & 3 deletions tfhe/src/integer/server_key/radix_parallel/vector_find.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::core_crypto::prelude::UnsignedInteger;
use crate::integer::block_decomposition::{BlockDecomposer, Decomposable, DecomposableInto};
use crate::integer::server_key::num_blocks_to_represent_unsigned_value;
use crate::integer::{BooleanBlock, IntegerRadixCiphertext, RadixCiphertext, ServerKey};
use crate::prelude::CastInto;
use crate::shortint::Ciphertext;
Expand Down Expand Up @@ -94,7 +95,7 @@ impl ServerKey {
.1;

let num_blocks_to_represent_values =
self.num_blocks_to_represent_unsigned_value(max_output_value);
num_blocks_to_represent_unsigned_value(max_output_value, self.message_modulus());

let possible_results_to_be_aggregated = self.create_possible_results(
num_blocks_to_represent_values,
Expand Down Expand Up @@ -203,15 +204,15 @@ impl ServerKey {
if matches.0.is_empty() {
return self.create_trivial_radix(
or_value,
self.num_blocks_to_represent_unsigned_value(or_value),
num_blocks_to_represent_unsigned_value(or_value, self.message_modulus()),
);
}
let (result, selected) = self.unchecked_match_value_parallelized(ct, matches);

// The result must have as many block to represent either the result of the match or the
// or_value
let num_blocks_to_represent_or_value =
self.num_blocks_to_represent_unsigned_value(or_value);
num_blocks_to_represent_unsigned_value(or_value, self.message_modulus());
let num_blocks = result.blocks.len().max(num_blocks_to_represent_or_value);
let or_value = self.create_trivial_radix(or_value, num_blocks);
let result = self.cast_to_unsigned(result, num_blocks);
Expand Down

0 comments on commit 22377bb

Please sign in to comment.