diff --git a/icicle/include/ntt/ntt_impl.cuh b/icicle/include/ntt/ntt_impl.cuh index 986436abc..5f5bf1e29 100644 --- a/icicle/include/ntt/ntt_impl.cuh +++ b/icicle/include/ntt/ntt_impl.cuh @@ -32,6 +32,7 @@ namespace mxntt { S* external_twiddles, S* internal_twiddles, S* basic_twiddles, + S* linear_twiddle, // twiddles organized as [1,w,w^2,...] for coset-eval in fast-tw mode int ntt_size, int max_logn, int batch_size, diff --git a/icicle/src/ntt/kernel_ntt.cu b/icicle/src/ntt/kernel_ntt.cu index 0c63fe081..3166b334c 100644 --- a/icicle/src/ntt/kernel_ntt.cu +++ b/icicle/src/ntt/kernel_ntt.cu @@ -134,14 +134,23 @@ namespace mxntt { int n_scalars, uint32_t log_size, eRevType rev_type, - bool dit, + bool fast_tw, E* out_vec) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= size * batch_size) return; int64_t scalar_id = (tid / columns_batch_size) % size; - if (rev_type != eRevType::None) - scalar_id = generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, false, rev_type); + if (rev_type != eRevType::None) { + // Note: when we multiply an in_vec that is mixed (by DIF (I)NTT), we want to shuffle the + // scalars the same way (then multiply element-wise). This would be a DIT-digit-reverse shuffle. (this is + // confusing but) BUT to avoid shuffling the scalars, we instead want to ask which element in the non-shuffled + // vec is now placed at index tid, which is the opposite of a DIT-digit-reverse --> this is the DIF-digit-reverse. + // Therefore we use the DIF-digit-reverse to know which element moved to index tid and use it to access the + // corresponding element in scalars vec. + const bool dif = rev_type == eRevType::NaturalToMixedRev; + scalar_id = + generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, !dif, fast_tw, rev_type); + } out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid]; } @@ -903,6 +912,7 @@ namespace mxntt { S* external_twiddles, S* internal_twiddles, S* basic_twiddles, + S* linear_twiddle, // twiddles organized as [1,w,w^2,...] for coset-eval in fast-tw mode int ntt_size, int max_logn, int batch_size, @@ -958,8 +968,8 @@ namespace mxntt { if (is_on_coset && !is_inverse) { batch_elementwise_mul_with_reorder_kernel<<>>( d_input, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1, - arbitrary_coset ? arbitrary_coset : external_twiddles, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, - reverse_coset, dit, d_output); + arbitrary_coset ? arbitrary_coset : linear_twiddle, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, + reverse_coset, fast_tw, d_output); d_input = d_output; } @@ -991,8 +1001,8 @@ namespace mxntt { if (is_on_coset && is_inverse) { batch_elementwise_mul_with_reorder_kernel<<>>( d_output, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1, - arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index, - n_twiddles, logn, reverse_coset, dit, d_output); + arbitrary_coset ? arbitrary_coset : linear_twiddle + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index, + n_twiddles, logn, reverse_coset, fast_tw, d_output); } return CHK_LAST(); @@ -1021,6 +1031,8 @@ namespace mxntt { scalar_t* external_twiddles, scalar_t* internal_twiddles, scalar_t* basic_twiddles, + scalar_t* linear_twiddles, + int ntt_size, int max_logn, int batch_size, @@ -1039,6 +1051,8 @@ namespace mxntt { scalar_t* external_twiddles, scalar_t* internal_twiddles, scalar_t* basic_twiddles, + scalar_t* linear_twiddles, + int ntt_size, int max_logn, int batch_size, diff --git a/icicle/src/ntt/ntt.cu b/icicle/src/ntt/ntt.cu index 46781bf44..dce454c3a 100644 --- a/icicle/src/ntt/ntt.cu +++ b/icicle/src/ntt/ntt.cu @@ -717,8 +717,7 @@ namespace ntt { d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset, coset_index, stream)); } else { - const bool is_on_coset = (coset_index != 0) || coset; - const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset; + const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr); S* twiddles = is_fast_twiddles_enabled ? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles) : domain.twiddles; @@ -728,9 +727,11 @@ namespace ntt { S* basic_twiddles = is_fast_twiddles_enabled ? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles) : domain.basic_twiddles; + S* linear_twiddles = domain.twiddles; // twiddles organized as [1,w,w^2,...] CHK_IF_RETURN(mxntt::mixed_radix_ntt( - d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size, - config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream)); + d_input, d_output, twiddles, internal_twiddles, basic_twiddles, linear_twiddles, size, domain.max_log_size, + batch_size, config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, + stream)); } } diff --git a/wrappers/rust/icicle-core/src/ntt/mod.rs b/wrappers/rust/icicle-core/src/ntt/mod.rs index c1df383ea..d343a2dd9 100644 --- a/wrappers/rust/icicle-core/src/ntt/mod.rs +++ b/wrappers/rust/icicle-core/src/ntt/mod.rs @@ -378,6 +378,13 @@ macro_rules! impl_ntt_tests { check_ntt_coset_from_subgroup::<$field>() } + #[test] + #[parallel] + fn test_ntt_coset_interpolation_nm() { + INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE)); + check_ntt_coset_interpolation_nm::<$field>(); + } + #[test] #[parallel] fn test_ntt_arbitrary_coset() { diff --git a/wrappers/rust/icicle-core/src/ntt/tests.rs b/wrappers/rust/icicle-core/src/ntt/tests.rs index 878013ea2..7e41e363c 100644 --- a/wrappers/rust/icicle-core/src/ntt/tests.rs +++ b/wrappers/rust/icicle-core/src/ntt/tests.rs @@ -190,6 +190,62 @@ where } } +pub fn check_ntt_coset_interpolation_nm() +where + F::ArkEquivalent: FftField, + ::Config: NTT + GenerateRandom, +{ + let test_sizes = [1 << 9, 1 << 10, 1 << 11, 1 << 13, 1 << 14, 1 << 16]; + for test_size in test_sizes { + let test_size_rou = F::ArkEquivalent::get_root_of_unity((test_size << 1) as u64).unwrap(); + let coset_generators = [F::from_ark(test_size_rou), F::Config::generate_random(1)[0]]; + + let scalars: Vec = F::Config::generate_random(test_size); + + let ark_domain = GeneralEvaluationDomain::::new(test_size).unwrap(); + + for coset_gen in coset_generators { + // (1) intt from evals to coeffs + let mut config = NTTConfig::default(); + config.ordering = Ordering::kNM; + config.ntt_algorithm = NttAlgorithm::MixedRadix; + + let mut intt_result = vec![F::zero(); test_size]; + let intt_result = HostSlice::from_mut_slice(&mut intt_result); + ntt(HostSlice::from_slice(&scalars), NTTDir::kInverse, &config, intt_result).unwrap(); + + let mut ark_scalars = scalars + .iter() + .map(|v| v.to_ark()) + .collect::>(); + ark_domain.ifft_in_place(&mut ark_scalars); + + // (2) coset-ntt (compute coset evals) + config.coset_gen = coset_gen; + config.ordering = Ordering::kMN; + let mut coset_evals = vec![F::zero(); test_size]; + ntt( + intt_result, + NTTDir::kForward, + &config, + HostSlice::from_mut_slice(&mut coset_evals), + ) + .unwrap(); + + let ark_coset_domain = ark_domain + .get_coset(coset_gen.to_ark()) + .unwrap(); + ark_coset_domain.fft_in_place(&mut ark_scalars); // to reuse in next iteration + + let coest_evals_as_ark = coset_evals + .iter() + .map(|v| v.to_ark()) + .collect::>(); + assert_eq!(coest_evals_as_ark, ark_scalars); + } + } +} + pub fn check_ntt_arbitrary_coset() where F::ArkEquivalent: FftField + ArkField,