Skip to content

Commit

Permalink
Merge pull request #178 from EamonnCarson/colored_noise_edgecase_shap…
Browse files Browse the repository at this point in the history
…e_error

Bugfix: AddColoredNoise edge-case shape error
  • Loading branch information
iver56 authored Sep 2, 2024
2 parents 362b0e0 + 9cb06d4 commit 2c6a272
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions tests/test_colored_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ def test_colored_noise_guaranteed_with_single_tensor(self):
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))

def test_colored_noise_guaranteed_with_single_tensor_edgecase_sample_rate(self):
signal = torch.zeros(1, 1, 16001)
mixed_input = self.cl_noise_transform_guaranteed(
signal, 16001
).samples
self.assertFalse(torch.equal(mixed_input, self.input_audio))
self.assertEqual(mixed_input.size(0), self.input_audio.size(0))
self.assertEqual(mixed_input.size(1), self.input_audio.size(1))

def test_colored_noise_guaranteed_with_batched_tensor(self):
random.seed(42)
mixed_inputs = self.cl_noise_transform_guaranteed(
Expand Down
2 changes: 1 addition & 1 deletion torch_audiomentations/augmentations/colored_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _gen_noise(f_decay, num_samples, sample_rate, device):
)
spec *= mask
noise = Audio.rms_normalize(irfft(spec).unsqueeze(0)).squeeze()
noise = torch.cat([noise] * int(ceil(num_samples / sample_rate)))
noise = torch.cat([noise] * int(ceil(num_samples / noise.shape[0])))
return noise[:num_samples]


Expand Down

0 comments on commit 2c6a272

Please sign in to comment.