Skip to content

Commit

Permalink
Allow dynamic shape in STFTSpectrogram layer. (#20736)
Browse files Browse the repository at this point in the history
by simply using `ops.shape(x)` instead of `x.shape`.
  • Loading branch information
hertschuh authored Jan 8, 2025
1 parent 26e71f5 commit f97be63
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
10 changes: 5 additions & 5 deletions keras/src/layers/preprocessing/stft_spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def build(self, input_shape):
self.built = True

def _adjust_shapes(self, outputs):
_, channels, freq_channels, time_seq = outputs.shape
_, channels, freq_channels, time_seq = ops.shape(outputs)
batch_size = -1
if self.data_format == "channels_last":
if self.expand_dims:
Expand All @@ -258,11 +258,11 @@ def _adjust_shapes(self, outputs):

def _apply_conv(self, inputs, kernel):
if self.data_format == "channels_last":
_, time_seq, channels = inputs.shape
_, time_seq, channels = ops.shape(inputs)
inputs = ops.transpose(inputs, [0, 2, 1])
inputs = ops.reshape(inputs, [-1, time_seq, 1])
else:
_, channels, time_seq = inputs.shape
_, channels, time_seq = ops.shape(inputs)
inputs = ops.reshape(inputs, [-1, 1, time_seq])

outputs = ops.conv(
Expand All @@ -274,14 +274,14 @@ def _apply_conv(self, inputs, kernel):
)
batch_size = -1
if self.data_format == "channels_last":
_, time_seq, freq_channels = outputs.shape
_, time_seq, freq_channels = ops.shape(outputs)
outputs = ops.transpose(outputs, [0, 2, 1])
outputs = ops.reshape(
outputs,
[batch_size, channels, freq_channels, time_seq],
)
else:
_, freq_channels, time_seq = outputs.shape
_, freq_channels, time_seq = ops.shape(outputs)
outputs = ops.reshape(
outputs,
[batch_size, channels, freq_channels, time_seq],
Expand Down
29 changes: 26 additions & 3 deletions keras/src/layers/preprocessing/stft_spectrogram_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,30 @@ def test_spectrogram_basics(self):
supports_masking=False,
)

@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="Backend does not support dynamic shapes",
)
def test_spectrogram_dynamic_shape(self):
model = Sequential(
[
Input(shape=(None, 1), dtype=TestSpectrogram.DTYPE),
layers.STFTSpectrogram(
frame_length=500,
frame_step=25,
fft_length=1024,
mode="stft",
data_format="channels_last",
),
]
)

def generator():
yield (np.random.random((2, 16000, 1)),)
yield (np.random.random((3, 8000, 1)),)

model.predict(generator())

@pytest.mark.requires_trainable_backend
def test_spectrogram_error(self):
rnd = np.random.RandomState(41)
Expand Down Expand Up @@ -310,10 +334,9 @@ def test_spectrogram_error(self):
init_args["mode"] = "angle"
y_true, y = self._calc_spectrograms(x, **init_args)

pi = np.arccos(np.float128(-1)).astype(y_true.dtype)
mask = np.isclose(y, y_true, **tol_kwargs)
mask |= np.isclose(y + 2 * pi, y_true, **tol_kwargs)
mask |= np.isclose(y - 2 * pi, y_true, **tol_kwargs)
mask |= np.isclose(y + 2 * np.pi, y_true, **tol_kwargs)
mask |= np.isclose(y - 2 * np.pi, y_true, **tol_kwargs)
mask |= np.isclose(np.cos(y), np.cos(y_true), **tol_kwargs)
mask |= np.isclose(np.sin(y), np.sin(y_true), **tol_kwargs)

Expand Down

0 comments on commit f97be63

Please sign in to comment.