From 0520ea1cda9f44abc90d5d33db6ce336355ba7d8 Mon Sep 17 00:00:00 2001
From: amanturamatov <amatoamant@gmail.com>
Date: Sun, 26 Jan 2025 02:06:53 +0600
Subject: [PATCH] Add support for variable audio duration in datasets

---
 oml/datasets/audios.py                        | 20 +++++++++++++++----
 .../test_oml/test_models/test_audio_models.py |  2 +-
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/oml/datasets/audios.py b/oml/datasets/audios.py
index 05c6f865..c6d3dbd2 100644
--- a/oml/datasets/audios.py
+++ b/oml/datasets/audios.py
@@ -1,3 +1,4 @@
+import warnings
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union
 
@@ -82,7 +83,9 @@ def __init__(
             len(paths) == len(start_times)
         ), "The length of 'start_times' must match the length of 'paths' if 'start_times' is provided."
         assert sample_rate > 0, "The sample rate must be a positive integer."
-        assert max_num_seconds > 0, "The maximum number of seconds must be a positive float."
+        assert (
+            max_num_seconds is None or max_num_seconds > 0
+        ), "The maximum number of seconds must be None or a positive float."
         assert isinstance(convert_to_mono, bool), "'convert_to_mono' must be a boolean."
 
         paths = [Path(p) if dataset_root is None else Path(dataset_root) / p for p in paths]
@@ -103,7 +106,15 @@ def __init__(
 
         self._paths = paths
         self._sample_rate = sample_rate
-        self._num_frames = int(max_num_seconds * sample_rate)
+
+        if max_num_seconds is None:
+            warnings.warn(
+                "max_num_seconds is None, so the audio files will not be trimmed or padded. "
+                "Additional collate_fn is required to handle different audio lengths."
+            )
+            self._num_frames = None
+        else:
+            self._num_frames = int(max_num_seconds * sample_rate)
         self._convert_to_mono = convert_to_mono
         self._frame_offsets = (
             [int(st * sample_rate) for st in start_times] if start_times is not None else [0] * len(paths)
@@ -165,9 +176,10 @@ def get_audio(self, item: int) -> FloatTensor:
         import torchaudio
 
         path = self._paths[item]
-        audio, sample_rate = torchaudio.load(path)
+        audio, sample_rate = torchaudio.load(str(path))
         audio = self._downmix_and_resample(audio, sample_rate)
-        audio = self._trim_or_pad(audio, self._frame_offsets[item], self._num_frames)
+        if self._num_frames is not None:
+            audio = self._trim_or_pad(audio, self._frame_offsets[item], self._num_frames)
         return audio
 
     def __getitem__(self, item: int) -> Dict[str, Union[FloatTensor, int]]:
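
Below, a minimal sketch of the collate_fn the warning above refers to (not part of this patch). It assumes each dataset item is the dict returned by __getitem__, with an "input_tensors" waveform of shape (channels, num_frames) and matching channel counts across the batch (e.g. via convert_to_mono=True); it right-pads every waveform with zeros to the longest clip. The name pad_audio_collate_fn is hypothetical.

    from typing import Any, Dict, List

    import torch

    def pad_audio_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Pad every waveform on the right to the longest clip in the batch.
        max_frames = max(item["input_tensors"].shape[-1] for item in batch)
        padded = [
            torch.nn.functional.pad(
                item["input_tensors"], (0, max_frames - item["input_tensors"].shape[-1])
            )
            for item in batch
        ]
        out: Dict[str, Any] = {"input_tensors": torch.stack(padded)}
        # Pass any other per-item fields (e.g. labels) through as lists.
        for key in batch[0]:
            if key != "input_tensors":
                out[key] = [item[key] for item in batch]
        return out
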
diff --git a/tests/test_oml/test_models/test_audio_models.py b/tests/test_oml/test_models/test_audio_models.py
index fc7bd6df..92c27379 100644
--- a/tests/test_oml/test_models/test_audio_models.py
+++ b/tests/test_oml/test_models/test_audio_models.py
@@ -48,7 +48,7 @@ def test_extractor(constructor: IExtractor, args: Dict[str, Any]) -> None:
 def test_checkpoints_from_zoo(constructor: IExtractor, weights: str) -> None:
 
     df_train, _ = get_mock_audios_dataset(global_paths=True)
-    dataset = AudioLabeledDataset(df_train)
+    dataset = AudioLabeledDataset(df_train, max_num_seconds=None)
 
     model = constructor.from_pretrained(weights).eval()
     emb1 = model.extract(dataset[0]["input_tensors"])
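
For reference, a hypothetical usage sketch tying the two pieces together: with max_num_seconds=None the dataset yields untrimmed waveforms, so the padding collate_fn sketched above is wired into a DataLoader (df_train and AudioLabeledDataset as in the test above).

    from torch.utils.data import DataLoader

    # The dataset yields variable-length waveforms, so batching
    # needs the padding collate_fn sketched earlier.
    dataset = AudioLabeledDataset(df_train, max_num_seconds=None)
    loader = DataLoader(dataset, batch_size=4, collate_fn=pad_audio_collate_fn)
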