Skip to content

Commit

Permalink
Merge pull request #4 from clearmatics/exceptions-fix
Browse files Browse the repository at this point in the history
Exceptions fix - merge scipr-lab/libfqfft#12 (depends on #6)
  • Loading branch information
AntoineRondelet authored Oct 13, 2021
2 parents 249c88a + 4877aed commit 620cef1
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace libfqfft {
FieldT arithmetic_generator;
void do_precomputation();

static bool valid_for_size(const size_t m);

arithmetic_sequence_domain(const size_t m);

void FFT(std::vector<FieldT> &a);
Expand Down
20 changes: 17 additions & 3 deletions libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@

namespace libfqfft {

template<typename FieldT>
bool arithmetic_sequence_domain<FieldT>::valid_for_size(const size_t m)
{
if (m <= 1) {
return false;
}

if (FieldT::arithmetic_generator() == FieldT::zero()) {
return false;
}

return true;
}

template<typename FieldT>
arithmetic_sequence_domain<FieldT>::arithmetic_sequence_domain(const size_t m) : evaluation_domain<FieldT>(m)
{
Expand All @@ -42,7 +56,7 @@ void arithmetic_sequence_domain<FieldT>::FFT(std::vector<FieldT> &a)

/* Monomial to Newton */
monomial_to_newton_basis(a, this->subproduct_tree, this->m);

/* Newton to Evaluation */
std::vector<FieldT> S(this->m); /* i! * arithmetic_generator */
S[0] = FieldT::one();
Expand Down Expand Up @@ -70,7 +84,7 @@ template<typename FieldT>
void arithmetic_sequence_domain<FieldT>::iFFT(std::vector<FieldT> &a)
{
if (a.size() != this->m) throw DomainSizeException("arithmetic: expected a.size() == this->m");

if (!this->precomputation_sentinel) do_precomputation();

/* Interpolation to Newton */
Expand Down Expand Up @@ -152,7 +166,7 @@ std::vector<FieldT> arithmetic_sequence_domain<FieldT>::evaluate_all_lagrange_po

std::vector<FieldT> w(this->m);
w[0] = g_vanish.inverse() * (this->arithmetic_generator^(this->m-1));

l[0] = l_vanish * l[0].inverse() * w[0];
for (size_t i = 1; i < this->m; i++)
{
Expand Down
2 changes: 2 additions & 0 deletions libfqfft/evaluation_domain/domains/basic_radix2_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class basic_radix2_domain : public evaluation_domain<FieldT> {

FieldT omega;

static bool valid_for_size(const size_t m);

basic_radix2_domain(const size_t m);

void FFT(std::vector<FieldT> &a);
Expand Down
14 changes: 14 additions & 0 deletions libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@

namespace libfqfft {

template<typename FieldT>
bool basic_radix2_domain<FieldT>::valid_for_size(const size_t m)
{
if (m <= 1) {
return false;
}

if (!libff::has_root_of_unity<FieldT>(m)) {
return false;
}

return true;
}

template<typename FieldT>
basic_radix2_domain<FieldT>::basic_radix2_domain(const size_t m) : evaluation_domain<FieldT>(m)
{
Expand Down
2 changes: 2 additions & 0 deletions libfqfft/evaluation_domain/domains/extended_radix2_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class extended_radix2_domain : public evaluation_domain<FieldT> {
FieldT omega;
FieldT shift;

static bool valid_for_size(const size_t m);

extended_radix2_domain(const size_t m);

void FFT(std::vector<FieldT> &a);
Expand Down
26 changes: 26 additions & 0 deletions libfqfft/evaluation_domain/domains/extended_radix2_domain.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,32 @@

namespace libfqfft {

template<typename FieldT>
bool extended_radix2_domain<FieldT>::valid_for_size(const size_t m)
{
if (m <= 1) {
return false;
}

// Will `get_root_of_unity` throw?
if (!std::is_same<FieldT, libff::Double>::value)
{
const size_t logm = libff::log2(m);

if (logm != (FieldT::s + 1)) {
return false;
}
}

size_t small_m = m / 2;

if (!libff::has_root_of_unity<FieldT>(small_m)) {
return false;
}

return true;
}

template<typename FieldT>
extended_radix2_domain<FieldT>::extended_radix2_domain(const size_t m) : evaluation_domain<FieldT>(m)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ namespace libfqfft {
std::vector<FieldT> geometric_triangular_sequence;
void do_precomputation();

static bool valid_for_size(const size_t m);

geometric_sequence_domain(const size_t m);

void FFT(std::vector<FieldT> &a);
Expand Down
20 changes: 17 additions & 3 deletions libfqfft/evaluation_domain/domains/geometric_sequence_domain.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,33 @@

namespace libfqfft {

template<typename FieldT>
bool geometric_sequence_domain<FieldT>::valid_for_size(const size_t m)
{
if (m <= 1) {
return false;
}

if (FieldT::geometric_generator() == FieldT::zero()) {
return false;
}

return true;
}

template<typename FieldT>
geometric_sequence_domain<FieldT>::geometric_sequence_domain(const size_t m) : evaluation_domain<FieldT>(m)
{
if (m <= 1) throw InvalidSizeException("geometric(): expected m > 1");
if (FieldT::geometric_generator() == FieldT::zero())
throw InvalidSizeException("geometric(): expected FieldT::geometric_generator() != FieldT::zero()");

precomputation_sentinel = 0;
}

template<typename FieldT>
void geometric_sequence_domain<FieldT>::FFT(std::vector<FieldT> &a)
{
{
if (a.size() != this->m) throw DomainSizeException("geometric: expected a.size() == this->m");

if (!this->precomputation_sentinel) do_precomputation();
Expand Down Expand Up @@ -71,7 +85,7 @@ template<typename FieldT>
void geometric_sequence_domain<FieldT>::iFFT(std::vector<FieldT> &a)
{
if (a.size() != this->m) throw DomainSizeException("geometric: expected a.size() == this->m");

if (!this->precomputation_sentinel) do_precomputation();

/* Interpolation to Newton */
Expand Down
2 changes: 2 additions & 0 deletions libfqfft/evaluation_domain/domains/step_radix2_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class step_radix2_domain : public evaluation_domain<FieldT> {
FieldT big_omega;
FieldT small_omega;

static bool valid_for_size(const size_t m);

step_radix2_domain(const size_t m);

void FFT(std::vector<FieldT> &a);
Expand Down
29 changes: 28 additions & 1 deletion libfqfft/evaluation_domain/domains/step_radix2_domain.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,33 @@

namespace libfqfft {

template<typename FieldT>
bool step_radix2_domain<FieldT>::valid_for_size(const size_t m)
{
if (m <= 1) {
return false;
}

const size_t big_m = 1ul<<(libff::log2(m)-1);
const size_t small_m = m - big_m;

if (small_m != 1ul<<libff::log2(small_m)) {
return false;
}

// omega
if (!libff::has_root_of_unity<FieldT>(1ul<<libff::log2(m))) {
return false;
}

// small_omega
if (!libff::has_root_of_unity<FieldT>(1ul<<libff::log2(small_m))) {
return false;
}

return true;
}

template<typename FieldT>
step_radix2_domain<FieldT>::step_radix2_domain(const size_t m) : evaluation_domain<FieldT>(m)
{
Expand All @@ -30,7 +57,7 @@ step_radix2_domain<FieldT>::step_radix2_domain(const size_t m) : evaluation_doma

try { omega = libff::get_root_of_unity<FieldT>(1ul<<libff::log2(m)); }
catch (const std::invalid_argument& e) { throw DomainSizeException(e.what()); }

big_omega = omega.squared();
small_omega = libff::get_root_of_unity<FieldT>(small_m);
}
Expand Down
1 change: 1 addition & 0 deletions libfqfft/evaluation_domain/evaluation_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define EVALUATION_DOMAIN_HPP_

#include <vector>
#include <libff/common/double.hpp>

namespace libfqfft {

Expand Down
36 changes: 27 additions & 9 deletions libfqfft/evaluation_domain/get_evaluation_domain.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,33 @@ std::shared_ptr<evaluation_domain<FieldT> > get_evaluation_domain(const size_t m
const size_t small = min_size - big;
const size_t rounded_small = (1ul<<libff::log2(small));

try { result.reset(new basic_radix2_domain<FieldT>(min_size)); }
catch(...) { try { result.reset(new extended_radix2_domain<FieldT>(min_size)); }
catch(...) { try { result.reset(new step_radix2_domain<FieldT>(min_size)); }
catch(...) { try { result.reset(new basic_radix2_domain<FieldT>(big + rounded_small)); }
catch(...) { try { result.reset(new extended_radix2_domain<FieldT>(big + rounded_small)); }
catch(...) { try { result.reset(new step_radix2_domain<FieldT>(big + rounded_small)); }
catch(...) { try { result.reset(new geometric_sequence_domain<FieldT>(min_size)); }
catch(...) { try { result.reset(new arithmetic_sequence_domain<FieldT>(min_size)); }
catch(...) { throw DomainSizeException("get_evaluation_domain: no matching domain"); }}}}}}}}
if (basic_radix2_domain<FieldT>::valid_for_size(min_size)) {
result.reset(new basic_radix2_domain<FieldT>(min_size));
}
else if (extended_radix2_domain<FieldT>::valid_for_size(min_size)) {
result.reset(new extended_radix2_domain<FieldT>(min_size));
}
else if (step_radix2_domain<FieldT>::valid_for_size(min_size)) {
result.reset(new step_radix2_domain<FieldT>(min_size));
}
else if (basic_radix2_domain<FieldT>::valid_for_size(big + rounded_small)) {
result.reset(new basic_radix2_domain<FieldT>(big + rounded_small));
}
else if (extended_radix2_domain<FieldT>::valid_for_size(big + rounded_small)) {
result.reset(new extended_radix2_domain<FieldT>(big + rounded_small));
}
else if (step_radix2_domain<FieldT>::valid_for_size(big + rounded_small)) {
result.reset(new step_radix2_domain<FieldT>(big + rounded_small));
}
else if (geometric_sequence_domain<FieldT>::valid_for_size(min_size)) {
result.reset(new geometric_sequence_domain<FieldT>(min_size));
}
else if (arithmetic_sequence_domain<FieldT>::valid_for_size(min_size)) {
result.reset(new arithmetic_sequence_domain<FieldT>(min_size));
}
else {
throw DomainSizeException("get_evaluation_domain: no matching domain");
}

return result;
}
Expand Down

0 comments on commit 620cef1

Please sign in to comment.