Skip to content

Commit

Permalink
Add variable audio duration in datasets
Browse files (browse the repository at this point in the history)
  • Loading branch information
amanteur committed Jan 25, 2025
1 parent a5ab6d2 commit 0520ea1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
20 changes: 16 additions & 4 deletions oml/datasets/audios.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -82,7 +83,9 @@ def __init__(
len(paths) == len(start_times)
), "The length of 'start_times' must match the length of 'paths' if 'start_times' is provided."
assert sample_rate > 0, "The sample rate must be a positive integer."
assert max_num_seconds > 0, "The maximum number of seconds must be a positive float."
assert (
max_num_seconds is None or max_num_seconds > 0
), "The maximum number of seconds must be None or a positive float."
assert isinstance(convert_to_mono, bool), "'convert_to_mono' must be a boolean."

paths = [Path(p) if dataset_root is None else Path(dataset_root) / p for p in paths]
Expand All @@ -103,7 +106,15 @@ def __init__(

self._paths = paths
self._sample_rate = sample_rate
self._num_frames = int(max_num_seconds * sample_rate)

if max_num_seconds is None:
warnings.warn(
"max_num_seconds is None, so the audio files will not be trimmed or padded. "
"Additional collate_fn is required to handle different audio lengths."
)
self._num_frames = None
else:
self._num_frames = int(max_num_seconds * sample_rate)
self._convert_to_mono = convert_to_mono
self._frame_offsets = (
[int(st * sample_rate) for st in start_times] if start_times is not None else [0] * len(paths)
Expand Down Expand Up @@ -165,9 +176,10 @@ def get_audio(self, item: int) -> FloatTensor:
import torchaudio

path = self._paths[item]
audio, sample_rate = torchaudio.load(path)
audio, sample_rate = torchaudio.load(str(path))
audio = self._downmix_and_resample(audio, sample_rate)
audio = self._trim_or_pad(audio, self._frame_offsets[item], self._num_frames)
if self._num_frames is not None:
audio = self._trim_or_pad(audio, self._frame_offsets[item], self._num_frames)
return audio

def __getitem__(self, item: int) -> Dict[str, Union[FloatTensor, int]]:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_oml/test_models/test_audio_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_extractor(constructor: IExtractor, args: Dict[str, Any]) -> None:
def test_checkpoints_from_zoo(constructor: IExtractor, weights: str) -> None:

df_train, _ = get_mock_audios_dataset(global_paths=True)
dataset = AudioLabeledDataset(df_train)
dataset = AudioLabeledDataset(df_train, max_num_seconds=None)

model = constructor.from_pretrained(weights).eval()
emb1 = model.extract(dataset[0]["input_tensors"])
Expand Down

0 comments on commit 0520ea1

Please sign in to comment.