Skip to content

Commit

Permalink
fix: bug regarding MixedRadix coset (I)NTT for NM/MN ordering (ingony…
Browse files Browse the repository at this point in the history
…ama-zk#497)

The bug is in how twiddles array is indexed when multiplied by a mixed
(M) vector to implement (I)NTT on cosets.
The fix is to use the DIF-digit-reverse to compute the index of the element in the
natural (N) vector that moved to index 'i' in the M vector. This is
emulating a DIT-digit-reverse (which is mixing like a DIF-compute)
reorder of the twiddles array and element-wise multiplication without
reordering the twiddles memory.
  • Loading branch information
yshekel authored Apr 25, 2024
1 parent f8d15e2 commit 36e288c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 11 deletions.
1 change: 1 addition & 0 deletions icicle/include/ntt/ntt_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace mxntt {
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
S* linear_twiddle, // twiddles organized as [1,w,w^2,...] for coset-eval in fast-tw mode
int ntt_size,
int max_logn,
int batch_size,
Expand Down
28 changes: 21 additions & 7 deletions icicle/src/ntt/kernel_ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,23 @@ namespace mxntt {
int n_scalars,
uint32_t log_size,
eRevType rev_type,
bool dit,
bool fast_tw,
E* out_vec)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= size * batch_size) return;
int64_t scalar_id = (tid / columns_batch_size) % size;
if (rev_type != eRevType::None)
scalar_id = generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, false, rev_type);
if (rev_type != eRevType::None) {
// Note: when we multiply an in_vec that is mixed (by DIF (I)NTT), we want to shuffle the
// scalars the same way (then multiply element-wise). This would be a DIT-digit-reverse shuffle. (this is
// confusing but) BUT to avoid shuffling the scalars, we instead want to ask which element in the non-shuffled
// vec is now placed at index tid, which is the opposite of a DIT-digit-reverse --> this is the DIF-digit-reverse.
// Therefore we use the DIF-digit-reverse to know which element moved to index tid and use it to access the
// corresponding element in scalars vec.
const bool dif = rev_type == eRevType::NaturalToMixedRev;
scalar_id =
generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, !dif, fast_tw, rev_type);
}
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
}

Expand Down Expand Up @@ -903,6 +912,7 @@ namespace mxntt {
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
S* linear_twiddle, // twiddles organized as [1,w,w^2,...] for coset-eval in fast-tw mode
int ntt_size,
int max_logn,
int batch_size,
Expand Down Expand Up @@ -958,8 +968,8 @@ namespace mxntt {
if (is_on_coset && !is_inverse) {
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
arbitrary_coset ? arbitrary_coset : external_twiddles, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn,
reverse_coset, dit, d_output);
arbitrary_coset ? arbitrary_coset : linear_twiddle, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn,
reverse_coset, fast_tw, d_output);

d_input = d_output;
}
Expand Down Expand Up @@ -991,8 +1001,8 @@ namespace mxntt {
if (is_on_coset && is_inverse) {
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index,
n_twiddles, logn, reverse_coset, dit, d_output);
arbitrary_coset ? arbitrary_coset : linear_twiddle + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index,
n_twiddles, logn, reverse_coset, fast_tw, d_output);
}

return CHK_LAST();
Expand Down Expand Up @@ -1021,6 +1031,8 @@ namespace mxntt {
scalar_t* external_twiddles,
scalar_t* internal_twiddles,
scalar_t* basic_twiddles,
scalar_t* linear_twiddles,

int ntt_size,
int max_logn,
int batch_size,
Expand All @@ -1039,6 +1051,8 @@ namespace mxntt {
scalar_t* external_twiddles,
scalar_t* internal_twiddles,
scalar_t* basic_twiddles,
scalar_t* linear_twiddles,

int ntt_size,
int max_logn,
int batch_size,
Expand Down
9 changes: 5 additions & 4 deletions icicle/src/ntt/ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -717,8 +717,7 @@ namespace ntt {
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
coset_index, stream));
} else {
const bool is_on_coset = (coset_index != 0) || coset;
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr);
S* twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
: domain.twiddles;
Expand All @@ -728,9 +727,11 @@ namespace ntt {
S* basic_twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
: domain.basic_twiddles;
S* linear_twiddles = domain.twiddles; // twiddles organized as [1,w,w^2,...]
CHK_IF_RETURN(mxntt::mixed_radix_ntt(
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, linear_twiddles, size, domain.max_log_size,
batch_size, config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index,
stream));
}
}

Expand Down
7 changes: 7 additions & 0 deletions wrappers/rust/icicle-core/src/ntt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,13 @@ macro_rules! impl_ntt_tests {
check_ntt_coset_from_subgroup::<$field>()
}

#[test]
#[parallel]
fn test_ntt_coset_interpolation_nm() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
check_ntt_coset_interpolation_nm::<$field>();
}

#[test]
#[parallel]
fn test_ntt_arbitrary_coset() {
Expand Down
56 changes: 56 additions & 0 deletions wrappers/rust/icicle-core/src/ntt/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,62 @@ where
}
}

pub fn check_ntt_coset_interpolation_nm<F: FieldImpl + ArkConvertible>()
where
F::ArkEquivalent: FftField,
<F as FieldImpl>::Config: NTT<F, F> + GenerateRandom<F>,
{
let test_sizes = [1 << 9, 1 << 10, 1 << 11, 1 << 13, 1 << 14, 1 << 16];
for test_size in test_sizes {
let test_size_rou = F::ArkEquivalent::get_root_of_unity((test_size << 1) as u64).unwrap();
let coset_generators = [F::from_ark(test_size_rou), F::Config::generate_random(1)[0]];

let scalars: Vec<F> = F::Config::generate_random(test_size);

let ark_domain = GeneralEvaluationDomain::<F::ArkEquivalent>::new(test_size).unwrap();

for coset_gen in coset_generators {
// (1) intt from evals to coeffs
let mut config = NTTConfig::default();
config.ordering = Ordering::kNM;
config.ntt_algorithm = NttAlgorithm::MixedRadix;

let mut intt_result = vec![F::zero(); test_size];
let intt_result = HostSlice::from_mut_slice(&mut intt_result);
ntt(HostSlice::from_slice(&scalars), NTTDir::kInverse, &config, intt_result).unwrap();

let mut ark_scalars = scalars
.iter()
.map(|v| v.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
ark_domain.ifft_in_place(&mut ark_scalars);

// (2) coset-ntt (compute coset evals)
config.coset_gen = coset_gen;
config.ordering = Ordering::kMN;
let mut coset_evals = vec![F::zero(); test_size];
ntt(
intt_result,
NTTDir::kForward,
&config,
HostSlice::from_mut_slice(&mut coset_evals),
)
.unwrap();

let ark_coset_domain = ark_domain
.get_coset(coset_gen.to_ark())
.unwrap();
ark_coset_domain.fft_in_place(&mut ark_scalars); // to reuse in next iteration

let coest_evals_as_ark = coset_evals
.iter()
.map(|v| v.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(coest_evals_as_ark, ark_scalars);
}
}
}

pub fn check_ntt_arbitrary_coset<F: FieldImpl + ArkConvertible>()
where
F::ArkEquivalent: FftField + ArkField,
Expand Down

0 comments on commit 36e288c

Please sign in to comment.