From 78e1c55b1a1800754b3e96509b361c181408afb6 Mon Sep 17 00:00:00 2001
From: Gabriel Santamaria <gaby.santamaria@outlook.fr>
Date: Sun, 10 Nov 2024 22:02:48 +0100
Subject: [PATCH] Adding asserts and managing negative axis in the FFT module.

---
 src/owl/fftpack/owl_fft_generic.ml | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/owl/fftpack/owl_fft_generic.ml b/src/owl/fftpack/owl_fft_generic.ml
index 0b4ce3d02..4d24c219c 100644
--- a/src/owl/fftpack/owl_fft_generic.ml
+++ b/src/owl/fftpack/owl_fft_generic.ml
@@ -11,6 +11,7 @@ let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x =
     | Some a -> a
     | None -> num_dims x - 1
   in
+  let axis = if axis < 0 then num_dims x + axis else axis in
   assert (axis < num_dims x);
   let y = empty (kind x) (shape x) in
   Owl_fftpack._owl_cfftf (kind x) x y axis norm nthreads;
@@ -23,6 +24,7 @@ let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x =
     | Some a -> a
     | None -> num_dims x - 1
   in
+  let axis = if axis < 0 then num_dims x + axis else axis in
   assert (axis < num_dims x);
   let y = empty (kind x) (shape x) in
   Owl_fftpack._owl_cfftb (kind x) x y axis norm nthreads;
@@ -35,6 +37,7 @@ let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x
     | Some a -> a
     | None -> num_dims x - 1
   in
+  let axis = if axis < 0 then num_dims x + axis else axis in
   assert (axis < num_dims x);
   let s = shape x in
   s.(axis) <- (s.(axis) / 2) + 1;
@@ -50,6 +53,7 @@ let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kin
     | Some a -> a
     | None -> num_dims x - 1
   in
+  let axis = if axis < 0 then num_dims x + axis else axis in
   assert (axis < num_dims x);
   let s = shape x in
   let _ =
@@ -73,12 +77,13 @@ let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads =
     | Some a -> a
     | None -> num_dims x - 1
   in
+  let axis = if axis < 0 then num_dims x + axis else axis in
+  assert (axis < num_dims x);
   let ortho =
     match ortho with
     | Some o -> o
     | None -> if norm = 2 then true else false
   in
-  assert (axis < num_dims x);
   assert (ttype > 0 || ttype < 5);
   let y = empty (kind x) (shape x) in
   Owl_fftpack._owl_dctf (kind x) x y axis ttype norm ortho nthreads;
@@ -91,12 +96,13 @@ let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads
     | Some a -> a
     | None -> num_dims x - 1
   in
+  let axis = if axis < 0 then num_dims x + axis else axis in
+  assert (axis < num_dims x);
   let ortho =
     match ortho with
     | Some o -> o
     | None -> if norm = 2 then true else false
   in
-  assert (axis < num_dims x);
   assert (ttype > 0 || ttype < 5);
   let y = empty (kind x) (shape x) in
   Owl_fftpack._owl_dctb (kind x) x y axis ttype norm ortho nthreads;
@@ -109,12 +115,13 @@ let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads =
     | Some a -> a
     | None -> num_dims x - 1
   in
+  let axis = if axis < 0 then num_dims x + axis else axis in
+  assert (axis < num_dims x);
   let ortho =
     match ortho with
     | Some o -> o
     | None -> if norm = 2 then true else false
   in
-  assert (axis < num_dims x);
   assert (ttype > 0 || ttype < 5);
   let y = empty (kind x) (shape x) in
   Owl_fftpack._owl_dstf (kind x) x y axis ttype norm ortho nthreads;
@@ -127,12 +134,13 @@ let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads
     | Some a -> a
     | None -> num_dims x - 1
   in
+  let axis = if axis < 0 then num_dims x + axis else axis in
+  assert (axis < num_dims x);
   let ortho =
     match ortho with
     | Some o -> o
     | None -> if norm = 2 then true else false
   in
-  assert (axis < num_dims x);
   assert (ttype > 0 || ttype < 5);
   let y = empty (kind x) (shape x) in
   Owl_fftpack._owl_dstb (kind x) x y axis ttype norm ortho nthreads;