From 84ed1c3e7b658eb9fa11aa915313902082d0857f Mon Sep 17 00:00:00 2001 From: Dharshannan Sugunan <94248626+Dharshannan@users.noreply.github.com> Date: Fri, 22 Sep 2023 11:08:22 +0100 Subject: [PATCH] stft_implementation # (#22578) --- ivy/functional/frontends/tensorflow/signal.py | 23 ++++++++ .../test_tensorflow/test_signal.py | 52 +++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/ivy/functional/frontends/tensorflow/signal.py b/ivy/functional/frontends/tensorflow/signal.py index 51481b62fa5a4..137e08e8133f5 100644 --- a/ivy/functional/frontends/tensorflow/signal.py +++ b/ivy/functional/frontends/tensorflow/signal.py @@ -38,6 +38,29 @@ def kaiser_window(window_length, beta=12.0, dtype=ivy.float32, name=None): return ivy.kaiser_window(window_length, periodic=False, beta=beta, dtype=dtype) +# stft +@to_ivy_arrays_and_back +def stft( + signals, + frame_length, + frame_step, + fft_length=None, + window_fn=None, + pad_end=False, + name=None, +): + signals = ivy.asarray(signals) + return ivy.stft( + signals, + frame_length, + frame_step, + fft_length=fft_length, + window_fn=window_fn, + pad_end=pad_end, + name=name, + ) + + @with_supported_dtypes( {"2.13.0 and below": ("float16", "float32", "float64", "bfloat16")}, "tensorflow", diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py index e7d37ab6da16d..cbdfe053ae9fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py @@ -31,6 +31,24 @@ def _valid_idct(draw): return dtype, x, type, n, axis, norm +@st.composite +def _valid_stft(draw): + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=["float32", "float64"], + max_value=65280, + min_value=-65280, + min_num_dims=1, + min_dim_size=2, + shared_dtype=True, + ) + ) + frame_length = draw(helpers.ints(min_value=16, max_value=100)) + frame_step = draw(helpers.ints(min_value=1, max_value=50)) + + return dtype, x, frame_length, frame_step + + # --- Main --- # # ------------ # @@ -194,6 +212,40 @@ def test_tensorflow_kaiser_window( ) +# test stft +@handle_frontend_test( + fn_tree="tensorflow.signal.stft", + dtype_x_and_args=_valid_stft(), + test_with_out=st.just(False), +) +def test_tensorflow_stft( + *, + dtype_x_and_args, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x, frame_length, frame_step = dtype_x_and_args + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + signals=x[0], + frame_length=frame_length, + frame_step=frame_step, + fft_length=None, + window_fn=None, + pad_end=True, + atol=1e-02, + rtol=1e-02, + ) + + # vorbis_window @handle_frontend_test( fn_tree="tensorflow.signal.vorbis_window",