Skip to content

Commit

Permalink
Adding asserts and managing negative axis in the FFT module.
Browse files Browse the repository at this point in the history
  • Loading branch information
gabyfle committed Nov 10, 2024
1 parent 91fec9b commit 78e1c55
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/owl/fftpack/owl_fft_generic.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x =
| Some a -> a
| None -> num_dims x - 1
in
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_cfftf (kind x) x y axis norm nthreads;
Expand All @@ -23,6 +24,7 @@ let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x =
| Some a -> a
| None -> num_dims x - 1
in
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_cfftb (kind x) x y axis norm nthreads;
Expand All @@ -35,6 +37,7 @@ let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x
| Some a -> a
| None -> num_dims x - 1
in
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let s = shape x in
s.(axis) <- (s.(axis) / 2) + 1;
Expand All @@ -50,6 +53,7 @@ let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kin
| Some a -> a
| None -> num_dims x - 1
in
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let s = shape x in
let _ =
Expand All @@ -73,12 +77,13 @@ let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads =
| Some a -> a
| None -> num_dims x - 1
in
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let ortho =
match ortho with
| Some o -> o
| None -> if norm = 2 then true else false
in
assert (axis < num_dims x);
assert (ttype > 0 || ttype < 5);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_dctf (kind x) x y axis ttype norm ortho nthreads;
Expand All @@ -91,12 +96,13 @@ let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads
| Some a -> a
| None -> num_dims x - 1
in
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let ortho =
match ortho with
| Some o -> o
| None -> if norm = 2 then true else false
in
assert (axis < num_dims x);
assert (ttype > 0 || ttype < 5);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_dctb (kind x) x y axis ttype norm ortho nthreads;
Expand All @@ -109,12 +115,13 @@ let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads =
| Some a -> a
| None -> num_dims x - 1
in
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let ortho =
match ortho with
| Some o -> o
| None -> if norm = 2 then true else false
in
assert (axis < num_dims x);
assert (ttype > 0 || ttype < 5);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_dstf (kind x) x y axis ttype norm ortho nthreads;
Expand All @@ -127,12 +134,13 @@ let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads
| Some a -> a
| None -> num_dims x - 1
in
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let ortho =
match ortho with
| Some o -> o
| None -> if norm = 2 then true else false
in
assert (axis < num_dims x);
assert (ttype > 0 || ttype < 5);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_dstb (kind x) x y axis ttype norm ortho nthreads;
Expand Down

0 comments on commit 78e1c55

Please sign in to comment.