Feature/unsup multichan waveform dataset #532
base: master
@@ -7,7 +7,7 @@
 from torch.nn import CrossEntropyLoss

 from lhotse import CutSet
-from lhotse.cut import Cut, MixedCut
+from lhotse.cut import Cut, MixedCut, PaddingCut
 from lhotse.utils import DEFAULT_PADDING_VALUE
@@ -315,6 +315,44 @@ def collate_multi_channel_audio(cuts: CutSet) -> torch.Tensor:
    return audio


def remove_pad_tracks(cuts):
    for cut in cuts:
        tracks = cut.tracks
        tracks_nopad = [t for t in tracks if not isinstance(t.cut, PaddingCut)]
        cut.tracks = tracks_nopad
    return cuts

def collate_multi_channel_audio_rmpad(cuts: CutSet) -> torch.Tensor:
    """
    Load audio samples for all the cuts and return them as a batch in a torch tensor.
    The cuts have to be of type ``MixedCut`` and their tracks will be interpreted as individual channels.
    The output shape is ``(batch, channel, time)``.
    The cuts will be padded with silence if necessary.
    """
    assert all(cut.has_recording for cut in cuts)
    assert all(isinstance(cut, MixedCut) for cut in cuts)
    # TODO: how to ensure that each track is synced across batches? i.e. dim=1 is the track index
    # and should correspond to the same mic across batches

    cuts = maybe_pad(cuts)

Review comment: This should not be needed, as you're manually zero-padding later.
    cuts = remove_pad_tracks(cuts)

Review comment: I think there is a pitfall here: what if a MixedCut has padding in between two cuts, or any variation of that situation? I don't think Lhotse would handle these situations well with your current code. Maybe you should try removing only the padding at the end (and the beginning, but for that one you have to be careful about modifying the offsets on the remaining tracks). Rather than manually removing PaddingCuts, I suggest using
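The reviewer's alternative (drop padding only at the edges of the mix and shift the remaining offsets, leaving padding between real cuts intact) could be sketched roughly as below. The `Track` dataclass is a hypothetical stand-in for lhotse's mixed-cut tracks, used here only so the idea is self-contained; it is not the real API.

```python
from dataclasses import dataclass


# Hypothetical stand-in for a lhotse MixedCut track (illustration only).
@dataclass
class Track:
    name: str
    offset: float      # start time of this track within the mix
    duration: float
    is_padding: bool = False


def trim_edge_padding(tracks):
    """Drop padding tracks only at the edges of the mix, then shift the
    remaining offsets so the mix still starts at time zero.
    Padding sandwiched between two real cuts is kept intact."""
    # Sort by start time so "leading" and "trailing" are well defined.
    tracks = sorted(tracks, key=lambda t: t.offset)
    while tracks and tracks[0].is_padding:
        tracks = tracks[1:]
    while tracks and tracks[-1].is_padding:
        tracks = tracks[:-1]
    if not tracks:
        return []
    # Re-anchor so the earliest remaining track starts at offset 0.
    shift = tracks[0].offset
    return [Track(t.name, t.offset - shift, t.duration, t.is_padding) for t in tracks]
```

Note how the offset shift handles the reviewer's caveat about leading padding: removing it changes where every remaining track starts, so the offsets must be rewritten, whereas trailing padding can be dropped without touching them.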

    # NOTE: what to do when the # of tracks is not the same across cuts; right now
    # this is zero-padding, but that seems bad ...

Review comment: I think you won't escape zero-padding of examples with fewer channels if you need to collate the data. However, I suggest you modify this function to return a 3-tuple:

Review comment: But in the end your models will have to somehow deal with the non-meaningful channels anyway. As long as you're working on same-number-of-channels data, no need to overthink this.
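One plausible reading of the 3-tuple suggestion is to return the zero-padded batch together with per-example channel and sample counts, so downstream code can mask the padding. A minimal pure-Python sketch (list-based rather than torch tensors; all names are hypothetical, not lhotse API):

```python
def collate_channels(waves):
    """Collate per-example multi-channel audio into a dense batch.

    waves: list of examples, each a list of channels, each a list of samples.
    Returns (batch, num_channels, num_samples): the zero-padded batch plus
    per-example channel and sample counts for masking downstream.
    """
    max_c = max(len(w) for w in waves)
    max_t = max(len(ch) for w in waves for ch in w)
    batch = []
    for w in waves:
        # Pad each real channel out to max_t samples, then append
        # all-zero channels up to max_c.
        ex = [list(ch) + [0.0] * (max_t - len(ch)) for ch in w]
        ex += [[0.0] * max_t for _ in range(max_c - len(w))]
        batch.append(ex)
    num_channels = [len(w) for w in waves]
    num_samples = [max(len(ch) for ch in w) for w in waves]
    return batch, num_channels, num_samples
```

With real tensors the shape would be ``(batch, channel, time)`` as in the function above; the point of the sketch is only the extra two outputs.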
    max_ntrack = max(len(cut.tracks) for cut in cuts)
    max_nsamp = max(cut.num_samples for cut in cuts)

    # NOTE: this ends up zero-padding! is that appropriate?
    audio = torch.zeros(len(cuts), max_ntrack, max_nsamp)
    for idx, cut in enumerate(cuts):
        ntrack = len(cut.tracks)
        nsamp = cut.num_samples
        audio[idx, 0:ntrack, 0:nsamp] = torch.from_numpy(cut.load_audio(mixed=False))

Review comment: Note that if you did

    return audio

def collate_vectors(
    tensors: Iterable[Union[torch.Tensor, np.ndarray]],
    padding_value: Union[int, float] = CrossEntropyLoss().ignore_index,
@@ -5,7 +5,7 @@
 from lhotse import validate
 from lhotse.augmentation import AugmentFn
 from lhotse.cut import CutSet
-from lhotse.dataset.collation import collate_audio, collate_features, collate_matrices
+from lhotse.dataset.collation import collate_audio, collate_features, collate_matrices, collate_multi_channel_audio_rmpad
 from lhotse.features import FeatureExtractor
@@ -74,6 +74,41 @@ def _validate(self, cuts: CutSet) -> None:
    assert all(cut.has_recording for cut in cuts)


class UnsupervisedMultiChanWaveformDataset(UnsupervisedDataset):

Review comment: Suggested change — somehow reads better to me.
    """
    A variant of UnsupervisedDataset that provides multi-channel waveform samples instead of features.
    The output is a tensor of shape (C, T), with C being the number of channels and T the number of audio samples.
    Returns:
    .. code-block::
        {
            'audio': (B x Channel x NumSamples) float tensor
            'audio_lens': (B, ) int tensor
        }
    """

    def __init__(self, collate: bool = True) -> None:
        super().__init__()
        self.collate = collate
    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
        if self.collate:
            audio = collate_multi_channel_audio_rmpad(cuts)
            audio_lens = 0  # TODO
            return {
                "cuts": cuts,
                "audio": audio,
                "audio_lens": audio_lens,
            }
        else:
            return {"cuts": cuts, "audio": [c.load_audio(mixed=False) for c in cuts]}

Review comment: This line would again have the extra padding channels problem. This suggests that maybe the solution should not be (entirely) in the
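If the non-meaningful padding channels end up in the batch anyway, one option raised in the review is letting the model mask them out. A hypothetical per-example channel-validity mask, built from the channel counts a collate function could return:

```python
def channel_mask(num_channels, max_channels):
    """Boolean mask over the padded channel slots of each example:
    True where the slot holds real audio, False where it is zero padding."""
    return [[c < n for c in range(max_channels)] for n in num_channels]
```

A model (or loss) can then zero out or skip the False slots instead of treating padded channels as real microphones.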

    def _validate(self, cuts: CutSet) -> None:
        validate(cuts)
        assert all(cut.has_recording for cut in cuts)
class DynamicUnsupervisedDataset(UnsupervisedDataset):
    """
    An example dataset that shows how to use on-the-fly feature extraction in Lhotse.

Review comment: You can ensure the tracks are sorted by some property; I imagine this is something very corpus-specific and should be done by the user, not by the library.
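As a corpus-specific illustration of that point (and of the earlier TODO about keeping dim=1 aligned to the same microphone across batches), a user could sort tracks by an identifier they know encodes the channel. The `_chN` suffix convention below is assumed purely for illustration; real corpora encode microphone identity differently.

```python
import re


def channel_index(track_id: str) -> int:
    """Parse a channel index from ids like 'utt1_ch2'.
    The '_chN' naming convention is a made-up example, not a lhotse rule."""
    m = re.search(r"_ch(\d+)$", track_id)
    return int(m.group(1)) if m else -1


def sort_track_ids(track_ids):
    # A stable, corpus-specific ordering makes dim=1 (the channel axis)
    # refer to the same microphone in every batch.
    return sorted(track_ids, key=channel_index)
```

The same key function could be applied to `cut.tracks` before collation; the library only needs the user to supply the ordering.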