From 0520ea1cda9f44abc90d5d33db6ce336355ba7d8 Mon Sep 17 00:00:00 2001 From: amanturamatov <amatoamant@gmail.com> Date: Sun, 26 Jan 2025 02:06:53 +0600 Subject: [PATCH] Add variable audio duration in datasets --- oml/datasets/audios.py | 20 +++++++++++++++---- .../test_oml/test_models/test_audio_models.py | 2 +- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/oml/datasets/audios.py b/oml/datasets/audios.py index 05c6f865..c6d3dbd2 100644 --- a/oml/datasets/audios.py +++ b/oml/datasets/audios.py @@ -1,3 +1,4 @@ +import warnings from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union @@ -82,7 +83,9 @@ def __init__( len(paths) == len(start_times) ), "The length of 'start_times' must match the length of 'paths' if 'start_times' is provided." assert sample_rate > 0, "The sample rate must be a positive integer." - assert max_num_seconds > 0, "The maximum number of seconds must be a positive float." + assert ( + max_num_seconds is None or max_num_seconds > 0 + ), "The maximum number of seconds must be None or a positive float." assert isinstance(convert_to_mono, bool), "'convert_to_mono' must be a boolean." paths = [Path(p) if dataset_root is None else Path(dataset_root) / p for p in paths] @@ -103,7 +106,15 @@ def __init__( self._paths = paths self._sample_rate = sample_rate - self._num_frames = int(max_num_seconds * sample_rate) + + if max_num_seconds is None: + warnings.warn( + "max_num_seconds is None, so the audio files will not be trimmed or padded. " + "Additional collate_fn is required to handle different audio lengths." + ) + self._num_frames = None + else: + self._num_frames = int(max_num_seconds * sample_rate) self._convert_to_mono = convert_to_mono self._frame_offsets = ( [int(st * sample_rate) for st in start_times] if start_times is not None else [0] * len(paths) @@ -165,9 +176,10 @@ def get_audio(self, item: int) -> FloatTensor: import torchaudio path = self._paths[item] - audio, sample_rate = torchaudio.load(path) + audio, sample_rate = torchaudio.load(str(path)) audio = self._downmix_and_resample(audio, sample_rate) - audio = self._trim_or_pad(audio, self._frame_offsets[item], self._num_frames) + if self._num_frames is not None: + audio = self._trim_or_pad(audio, self._frame_offsets[item], self._num_frames) return audio def __getitem__(self, item: int) -> Dict[str, Union[FloatTensor, int]]: diff --git a/tests/test_oml/test_models/test_audio_models.py b/tests/test_oml/test_models/test_audio_models.py index fc7bd6df..92c27379 100644 --- a/tests/test_oml/test_models/test_audio_models.py +++ b/tests/test_oml/test_models/test_audio_models.py @@ -48,7 +48,7 @@ def test_extractor(constructor: IExtractor, args: Dict[str, Any]) -> None: def test_checkpoints_from_zoo(constructor: IExtractor, weights: str) -> None: df_train, _ = get_mock_audios_dataset(global_paths=True) - dataset = AudioLabeledDataset(df_train) + dataset = AudioLabeledDataset(df_train, max_num_seconds=None) model = constructor.from_pretrained(weights).eval() emb1 = model.extract(dataset[0]["input_tensors"])