Skip to content

Commit

Permalink
Little code change for more similarities with scipy's implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
gabyfle committed Nov 19, 2024
1 parent 2ab8c8d commit 8fa2d85
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
4 changes: 2 additions & 2 deletions src/owl/fftpack/owl_fft_generic.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ type tnorm =

let tnorm_to_int = function
| Backward -> 0
| Forward -> 1
| Ortho -> 2
| Forward -> 2
| Ortho -> 1

let fft ?axis ?(norm : tnorm = Backward) ?(nthreads : int = 1) x =
let axis =
Expand Down
46 changes: 23 additions & 23 deletions src/owl/fftpack/owl_fftpack_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,34 @@

using namespace pocketfft::detail;

using ldbl_t = typename std::conditional<
sizeof(long double) == sizeof(double), double, long double>::type;

template <typename T>
T norm_fct(int inorm, size_t N)
{
switch (inorm)
{
case 0: // "backward" - no normalization for forward transform
if (inorm == 0)
return T(1);
case 1: // "forward" - 1/n normalization for forward transform
return T(1) / T(N);
case 2: // "ortho" - 1/sqrt(n) normalization for both directions
return T(1) / std::sqrt(T(N));
default:
caml_failwith("invalid value for inorm (must be 0, 1, or 2)");
// This will never be reached
return T(0);
}
if (inorm == 2)
return T(1 / ldbl_t(N));
if (inorm == 1)
return T(1 / sqrt(ldbl_t(N)));
caml_failwith("invalid value for norm (must be 0, 1, or 2)"); // could make use of caml exections
// This will never be reached
return T(0);
}

template <typename T>
T compute_norm_factor(const shape_t &dims, const shape_t &axes, int inorm, size_t fct = 1, int delta = 0)
T norm_fct(int inorm, const shape_t &shape,
const shape_t &axes, size_t fct = 1, int delta = 0)
{
if (inorm == 0)
return T(1);

size_t N = 1;
for (auto a : axes)
{
N *= fct * size_t(int64_t(dims[a]) + delta);
}
N *= fct * size_t(int64_t(shape[a]) + delta);

return norm_fct<T>(inorm, N);
}

Expand Down Expand Up @@ -95,7 +95,7 @@ value STUB_CFFT(value vForward, value vX, value vY, value vD, value vNorm, value

shape_t axes{static_cast<size_t>(d)};
{
Treal norm_factor = compute_norm_factor<Treal>(dims, axes, norm);
Treal norm_factor = norm_fct<Treal>(norm, dims, axes);
try
{
pocketfft::c2c(dims, stride_in, stride_out, axes, forward,
Expand Down Expand Up @@ -166,7 +166,7 @@ value STUB_RFFTF(value vX, value vY, value vD, value vNorm, value vNthreads)

shape_t axes{static_cast<size_t>(d)};
{
Treal norm_factor = compute_norm_factor<Treal>(dims, axes, norm);
Treal norm_factor = norm_fct<Treal>(norm, dims, axes);
try
{
pocketfft::r2c(dims, stride_in, stride_out, axes, pocketfft::FORWARD,
Expand Down Expand Up @@ -239,7 +239,7 @@ value STUB_RFFTB(value vX, value vY, value vD, value vNorm, value vNthreads)

shape_t axes{static_cast<size_t>(d)};
{
Treal norm_factor = compute_norm_factor<Treal>(dims, axes, norm);
Treal norm_factor = norm_fct<Treal>(norm, dims, axes);
try
{
pocketfft::c2r(dims, stride_in, stride_out, axes, pocketfft::BACKWARD,
Expand Down Expand Up @@ -299,8 +299,8 @@ value STUB_RDCT(value vX, value vY, value vD, value vType, value vNorm, value vO

shape_t axes{static_cast<size_t>(d)};
{
Treal norm_factor = (type == 1) ? compute_norm_factor<Treal>(dims, axes, norm, 2, -1)
: compute_norm_factor<Treal>(dims, axes, norm, 2);
Treal norm_factor = (type == 1) ? norm_fct<Treal>(norm, dims, axes, 2, -1)
: norm_fct<Treal>(norm, dims, axes, 2);
try
{
pocketfft::dct(dims, stride_in, stride_out, axes, type,
Expand Down Expand Up @@ -365,8 +365,8 @@ value STUB_RDST(value vX, value vY, value vD, value vType, value vNorm, value vO

shape_t axes{static_cast<size_t>(d)};
{
Treal norm_factor = (type == 1) ? compute_norm_factor<Treal>(dims, axes, norm, 2, 1)
: compute_norm_factor<Treal>(dims, axes, norm, 2);
Treal norm_factor = (type == 1) ? norm_fct<Treal>(norm, dims, axes, 2, 1)
: norm_fct<Treal>(norm, dims, axes, 2);
try
{
pocketfft::dst(dims, stride_in, stride_out, axes, type,
Expand Down

0 comments on commit 8fa2d85

Please sign in to comment.