Feature/unsup multichan waveform dataset #532

Open
wants to merge 5 commits into base: master
40 changes: 39 additions & 1 deletion lhotse/dataset/collation.py

@@ -7,7 +7,7 @@
from torch.nn import CrossEntropyLoss

from lhotse import CutSet
-from lhotse.cut import Cut, MixedCut
+from lhotse.cut import Cut, MixedCut, PaddingCut
from lhotse.utils import DEFAULT_PADDING_VALUE


@@ -315,6 +315,44 @@ def collate_multi_channel_audio(cuts: CutSet) -> torch.Tensor:
    return audio


def remove_pad_tracks(cuts: CutSet) -> CutSet:
    """Drop ``PaddingCut`` tracks from each ``MixedCut`` so that only real channels remain."""
    for cut in cuts:
        cut.tracks = [t for t in cut.tracks if not isinstance(t.cut, PaddingCut)]
    return cuts


def collate_multi_channel_audio_rmpad(cuts: CutSet) -> torch.Tensor:
    """
    Load audio samples for all the cuts and return them as a batch in a torch tensor.
    The cuts have to be of type ``MixedCut``, and their tracks are interpreted as individual channels.
    Padding tracks are removed first; the batch is then zero-padded up to the maximum
    number of tracks and the maximum number of samples in the batch.
    The output shape is ``(batch, channel, time)``.
    """
    assert all(cut.has_recording for cut in cuts)
    assert all(isinstance(cut, MixedCut) for cut in cuts)

    # TODO: how to ensure that each track is synced across batches? I.e., dim=1 is the track
    # index and should correspond to the same mic across batches.
Collaborator: You can ensure the tracks are sorted by some property; I imagine this is something very corpus-specific and should be done by the user, not by the library.
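A minimal sketch of that user-side ordering (the sort key is a hypothetical choice; use whatever property identifies the microphone in your corpus):

```python
# Sort each MixedCut's tracks by the underlying recording id so that dim=1
# (the channel index) maps to the same microphone in every batch.
# Assumes padding tracks were already removed (a PaddingCut has no recording_id).
for cut in cuts:
    cut.tracks = sorted(cut.tracks, key=lambda t: t.cut.recording_id)
```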


    cuts = maybe_pad(cuts)
Collaborator: This should not be needed, as you're manually zero-padding later.

    cuts = remove_pad_tracks(cuts)
Collaborator: I think there is a pitfall here: what if a MixedCut looks like

|-------cut1-------||---padding---||----cut2----|

or any variation of the situation where padding lies between two cuts? I don't think Lhotse would handle these situations well with your current code. Maybe you should try removing only the padding at the end (and at the beginning, but for that one you have to be careful about modifying the offsets of the remaining tracks). Rather than manually removing PaddingCuts, I suggest using .truncate() with carefully computed offset and duration arguments; that method handles a lot of pitfalls and edge cases.
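A minimal sketch of the .truncate() route (assuming only leading/trailing padding should be dropped and at least one real track exists; trim_edge_padding is an illustrative name):

```python
from lhotse.cut import MixedCut, PaddingCut

def trim_edge_padding(cut: MixedCut) -> MixedCut:
    """Trim padding before the first and after the last real track via truncate()."""
    real = [t for t in cut.tracks if not isinstance(t.cut, PaddingCut)]
    start = min(t.offset for t in real)
    end = max(t.offset + t.cut.duration for t in real)
    # truncate() recomputes track offsets and durations consistently,
    # which is exactly where manual removal of PaddingCuts goes wrong.
    return cut.truncate(offset=start, duration=end - start)
```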


    # NOTE: what to do when the number of tracks is not the same across cuts? Right now
    # this is zero-padding, but that seems bad ...
Collaborator: I think you won't escape zero-padding of examples with fewer channels if you need to collate the data. However, I suggest you modify this function to return a 3-tuple (audio, audio_lens, channel_indexes): audio is the collated data with shape (B, C, T); audio_lens holds the length of each multi-channel example and has shape (B,); channel_indexes is a list of lists saying which C-dim indexes hold meaningful channels for each example (it could also be a channel_lens tensor of shape (B,), assuming the first c channels are always meaningful, if that can be guaranteed).

Collaborator: But in the end your models will have to somehow deal with the non-meaningful channels anyway. As long as you're working on same-number-of-channels data, there's no need to overthink this.
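A minimal sketch of the suggested 3-tuple signature (the function name is illustrative, it reuses remove_pad_tracks from this PR, and it assumes the first c channels are always the meaningful ones):

```python
import torch

def collate_multi_channel_audio_with_lens(cuts):
    cuts = remove_pad_tracks(cuts)
    B = len(cuts)
    C = max(len(cut.tracks) for cut in cuts)
    T = max(cut.num_samples for cut in cuts)
    audio = torch.zeros(B, C, T)
    audio_lens = torch.zeros(B, dtype=torch.int32)
    channel_indexes = []
    for idx, cut in enumerate(cuts):
        c, t = len(cut.tracks), cut.num_samples
        audio[idx, :c, :t] = torch.from_numpy(cut.load_audio(mixed=False))
        audio_lens[idx] = t
        # The first c channels carry real audio; the rest are zero padding.
        channel_indexes.append(list(range(c)))
    return audio, audio_lens, channel_indexes
```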

    max_ntrack = max(len(cut.tracks) for cut in cuts)
    max_nsamp = max(cut.num_samples for cut in cuts)

    # NOTE: this ends up zero-padding! Is that appropriate?
    audio = torch.zeros(len(cuts), max_ntrack, max_nsamp)
    for idx, cut in enumerate(cuts):
        ntrack = len(cut.tracks)
        nsamp = cut.num_samples
        audio[idx, 0:ntrack, 0:nsamp] = torch.from_numpy(cut.load_audio(mixed=False))
Collaborator: Note that if you did cut.mix(musan_cut) here, it would also add an extra track; as is, the code would not work with additive-noise data augmentation.
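A rough illustration of the caveat (musan_cut stands in for any noise cut):

```python
# Mixing returns a MixedCut with one extra track appended, so the channel
# dimension would silently grow under additive-noise augmentation.
noisy = cut.mix(musan_cut)
assert len(noisy.tracks) == len(cut.tracks) + 1
```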

    return audio


def collate_vectors(
    tensors: Iterable[Union[torch.Tensor, np.ndarray]],
    padding_value: Union[int, float] = CrossEntropyLoss().ignore_index,
37 changes: 36 additions & 1 deletion lhotse/dataset/unsupervised.py

@@ -5,7 +5,7 @@
from lhotse import validate
from lhotse.augmentation import AugmentFn
from lhotse.cut import CutSet
-from lhotse.dataset.collation import collate_audio, collate_features, collate_matrices
+from lhotse.dataset.collation import collate_audio, collate_features, collate_matrices, collate_multi_channel_audio_rmpad
from lhotse.features import FeatureExtractor


@@ -74,6 +74,41 @@ def _validate(self, cuts: CutSet) -> None:
        assert all(cut.has_recording for cut in cuts)


class UnsupervisedMultiChanWaveformDataset(UnsupervisedDataset):
Collaborator: Suggested change:

-class UnsupervisedMultiChanWaveformDataset(UnsupervisedDataset):
+class MultiChannelWaveformDataset(UnsupervisedDataset):

It somehow reads better to me.

"""
A variant of UnsupervisedDataset that provides waveform samples instead of features.
The output is a tensor of shape (C, T), with C being the number of channels and T the number of audio samples.
In this implementation, there will always be a single channel.
Returns:
.. code-block::
{
'audio': (B x NumSamples) float tensor
'audio_lens': (B, ) int tensor
}
"""

    def __init__(self, collate: bool = True) -> None:
        super().__init__()
        self.collate = collate

    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
        if self.collate:
            audio = collate_multi_channel_audio_rmpad(cuts)
            # Per-example lengths in samples (padding tracks have already been removed).
            audio_lens = torch.tensor([cut.num_samples for cut in cuts], dtype=torch.int32)
            return {
                "cuts": cuts,
                "audio": audio,
                "audio_lens": audio_lens,
            }
        else:
            return {"cuts": cuts, "audio": [c.load_audio(mixed=False) for c in cuts]}
Collaborator: This line would again have the extra-padding-channels problem. This suggests that maybe the solution should not be (entirely) in the collate function, but inside load_audio, e.g. controlled by an extra argument?
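A hypothetical sketch of that idea (lhotse's load_audio has no such argument today; the helper below merely emulates a drop_padding_tracks=True flag):

```python
from copy import deepcopy

import numpy as np
from lhotse.cut import MixedCut, PaddingCut

def load_audio_real_channels(cut: MixedCut) -> np.ndarray:
    """Load per-track audio while skipping padding tracks."""
    trimmed = deepcopy(cut)
    # Drop padding tracks on a copy so that only real channels get stacked.
    trimmed.tracks = [t for t in trimmed.tracks if not isinstance(t.cut, PaddingCut)]
    return trimmed.load_audio(mixed=False)  # shape: (num_real_tracks, num_samples)
```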


    def _validate(self, cuts: CutSet) -> None:
        validate(cuts)
        assert all(cut.has_recording for cut in cuts)



class DynamicUnsupervisedDataset(UnsupervisedDataset):
"""
An example dataset that shows how to use on-the-fly feature extraction in Lhotse.
Expand Down