From 78e1c55b1a1800754b3e96509b361c181408afb6 Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria <gaby.santamaria@outlook.fr> Date: Sun, 10 Nov 2024 22:02:48 +0100 Subject: [PATCH] Adding asserts and managing negative axis in the FFT module. --- src/owl/fftpack/owl_fft_generic.ml | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/owl/fftpack/owl_fft_generic.ml b/src/owl/fftpack/owl_fft_generic.ml index 0b4ce3d02..4d24c219c 100644 --- a/src/owl/fftpack/owl_fft_generic.ml +++ b/src/owl/fftpack/owl_fft_generic.ml @@ -11,6 +11,7 @@ let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x = | Some a -> a | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in Owl_fftpack._owl_cfftf (kind x) x y axis norm nthreads; @@ -23,6 +24,7 @@ let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x = | Some a -> a | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in Owl_fftpack._owl_cfftb (kind x) x y axis norm nthreads; @@ -35,6 +37,7 @@ let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x | Some a -> a | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let s = shape x in s.(axis) <- (s.(axis) / 2) + 1; @@ -50,6 +53,7 @@ let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kin | Some a -> a | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let s = shape x in let _ = @@ -73,12 +77,13 @@ let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = | Some a -> a | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); let ortho = match ortho with | Some o -> o | None -> if norm = 2 then true else false in - assert (axis < num_dims x); assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in Owl_fftpack._owl_dctf (kind x) x y axis ttype norm ortho nthreads; @@ -91,12 +96,13 @@ let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads | Some a -> a | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); let ortho = match ortho with | Some o -> o | None -> if norm = 2 then true else false in - assert (axis < num_dims x); assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in Owl_fftpack._owl_dctb (kind x) x y axis ttype norm ortho nthreads; @@ -109,12 +115,13 @@ let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = | Some a -> a | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); let ortho = match ortho with | Some o -> o | None -> if norm = 2 then true else false in - assert (axis < num_dims x); assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in Owl_fftpack._owl_dstf (kind x) x y axis ttype norm ortho nthreads; @@ -127,12 +134,13 @@ let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads | Some a -> a | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); let ortho = match ortho with | Some o -> o | None -> if norm = 2 then true else false in - assert (axis < num_dims x); assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in Owl_fftpack._owl_dstb (kind x) x y axis ttype norm ortho nthreads;