Sortformer Diarizer 4spk v1 model PR Part 3: Speaker Diarization Mixin #11511

Merged on Dec 18, 2024 (48 commits)
Showing changes from 5 commits
Commits
018df0b
Adding diarization mixin for one click inference
tango4j Dec 6, 2024
5ad4adf
Merge branch 'NVIDIA:main' into sortformer/pr_mixin
tango4j Dec 9, 2024
27a7e72
Apply isort and black reformatting
tango4j Dec 9, 2024
ab83467
Resolving CodeQL and Pylint
tango4j Dec 9, 2024
fe7ac1b
Resolving CodeQL and Pylint
tango4j Dec 9, 2024
046bf8c
Resolving CodeQL and Pylint - unsaved files resolved
tango4j Dec 9, 2024
4a52681
Apply isort and black reformatting
tango4j Dec 9, 2024
86374c0
Unused package manifest_utils
tango4j Dec 9, 2024
59070f7
Merge branch 'sortformer/pr_mixin' of https://github.com/tango4j/NeMo…
tango4j Dec 9, 2024
4ae6f75
Merge branch 'main' into sortformer/pr_mixin
tango4j Dec 9, 2024
b7b5e4a
Merge branch 'main' into sortformer/pr_mixin
tango4j Dec 10, 2024
c2f8193
Resolved diarization mixin test issues
tango4j Dec 10, 2024
e838746
Merge branch 'sortformer/pr_mixin' of https://github.com/tango4j/NeMo…
tango4j Dec 10, 2024
7e57ce5
Merge branch 'main' into sortformer/pr_mixin
tango4j Dec 10, 2024
9addedb
Apply isort and black reformatting
tango4j Dec 10, 2024
56f41cc
Removed commented lines
tango4j Dec 10, 2024
b1b8dd9
updating mixins code
ipmedenn Dec 12, 2024
17dca5e
Apply isort and black reformatting
ipmedenn Dec 12, 2024
6eb151b
fixing test_diarizartion.py
ipmedenn Dec 12, 2024
b2da316
moving diarization postprocessing-related stuff to vad_utils.py
ipmedenn Dec 12, 2024
383d005
Apply isort and black reformatting
ipmedenn Dec 12, 2024
cd1bcbd
Resolving PyLint
ipmedenn Dec 13, 2024
7965429
Apply isort and black reformatting
ipmedenn Dec 13, 2024
43f3ad3
Merge remote-tracking branch 'origin/main' into sortformer/pr_mixin
ipmedenn Dec 13, 2024
7c54189
Merge branch 'main' into sortformer/pr_mixin
tango4j Dec 14, 2024
00be40c
fixing batch_idx issue in sortformer_diar_models.py
ipmedenn Dec 16, 2024
2317a7f
adding sync_dist=True for sortformer validation
ipmedenn Dec 16, 2024
716864e
Merge branch 'main' into sortformer/pr_mixin
tango4j Dec 16, 2024
b66858e
Adding changes on model file
tango4j Dec 17, 2024
a0c10ff
Merge branch 'sortformer/pr_mixin' of https://github.com/tango4j/NeMo…
tango4j Dec 17, 2024
3434e93
Apply isort and black reformatting
tango4j Dec 17, 2024
2e7d64a
Removed unused pretrained_name and audio_dir in e2e_diarize_speech
tango4j Dec 17, 2024
e7595b4
Merge branch 'sortformer/pr_mixin' of https://github.com/tango4j/NeMo…
tango4j Dec 17, 2024
db3100f
Merge branch 'main' into sortformer/pr_mixin
tango4j Dec 17, 2024
ea79127
Reflecting the comments from PR
tango4j Dec 17, 2024
8499c1c
Merge branch 'main' into sortformer/pr_mixin
tango4j Dec 17, 2024
c8eab9c
Reflecting the comments from PR 2nd
tango4j Dec 17, 2024
621698b
Merge branch 'sortformer/pr_mixin' of https://github.com/tango4j/NeMo…
tango4j Dec 17, 2024
b87e6ce
Apply isort and black reformatting
tango4j Dec 17, 2024
a15d24f
Merge branch 'main' into sortformer/pr_mixin
tango4j Dec 17, 2024
40a2fd4
Resolved a codeQL unused variable
tango4j Dec 17, 2024
2b9b72f
Merge branch 'sortformer/pr_mixin' of https://github.com/tango4j/NeMo…
tango4j Dec 17, 2024
739a5d5
Now moved existance check after
tango4j Dec 18, 2024
346fc23
Apply isort and black reformatting
tango4j Dec 18, 2024
46b01d5
Removed the long lines
tango4j Dec 18, 2024
1ce781a
Merge branch 'sortformer/pr_mixin' of https://github.com/tango4j/NeMo…
tango4j Dec 18, 2024
07e7199
Merge branch 'main' into sortformer/pr_mixin
tango4j Dec 18, 2024
7d4d10a
Apply isort and black reformatting
tango4j Dec 18, 2024
@@ -48,35 +48,22 @@
import yaml
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from tqdm import tqdm

from nemo.collections.asr.metrics.der import score_labels
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object
from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing
from nemo.collections.asr.models.sortformer_diar_models import PostProcessingParams
from nemo.collections.asr.parts.utils.speaker_utils import (
audio_rttm_map,
get_uniqname_from_filepath,
timestamps_to_pyannote_object,
)
from nemo.collections.asr.parts.utils.vad_utils import predlist_to_timestamps
from nemo.core.config import hydra_runner

seed_everything(42)
torch.backends.cudnn.deterministic = True


@dataclass
class PostProcessingParams:
"""
Postprocessing parameters for end-to-end speaker diarization models.
These parameters can significantly affect DER performance depending on the evaluation style and the dataset.
It is recommended to tune these parameters based on the evaluation style and the dataset
to achieve the desired DER performance.
"""

onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech
offset: float = 0.5 # Offset threshold for detecting the end of a speech
pad_onset: float = 0.0 # Adding durations before each speech segment
pad_offset: float = 0.0 # Adding durations after each speech segment
min_duration_on: float = 0.0 # Threshold for small non-speech deletion
min_duration_off: float = 0.0 # Threshold for short speech segment deletion


@dataclass
class DiarizationConfig:
"""Diarization configuration parameters for inference."""
@@ -303,26 +290,19 @@ def convert_pred_mat_to_segments(
"""
batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], []
cfg_vad_params = OmegaConf.structured(postprocessing_cfg)
pp_message = "Bypass PP, Running Binarization" if bypass_postprocessing else "Running post-processing"
for sample_idx, (uniq_id, audio_rttm_values) in tqdm(
enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc=pp_message
):
spk_ts = []
offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration']
speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0)
speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])]
for spk_id in range(speaker_assign_mat.shape[-1]):
ts_mat = ts_vad_post_processing(
speaker_assign_mat[:, spk_id],
cfg_vad_params=cfg_vad_params,
unit_10ms_frame_count=unit_10ms_frame_count,
bypass_postprocessing=bypass_postprocessing,
)
ts_mat = ts_mat + offset
ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration))
ts_seg_list = ts_mat.tolist()
speaker_timestamps[spk_id].extend(ts_seg_list)
spk_ts.append(ts_seg_list)
total_speaker_timestamps = predlist_to_timestamps(
batch_preds_list=batch_preds_list,
audio_rttm_map_dict=audio_rttm_map_dict,
cfg_vad_params=cfg_vad_params,
unit_10ms_frame_count=unit_10ms_frame_count,
bypass_postprocessing=bypass_postprocessing,
)
for sample_idx, (uniq_id, audio_rttm_values) in enumerate(audio_rttm_map_dict.items()):
speaker_timestamps = total_speaker_timestamps[sample_idx]
if audio_rttm_values.get("uniq_id", None) is not None:
uniq_id = audio_rttm_values["uniq_id"]
else:
uniq_id = get_uniqname_from_filepath(audio_rttm_values["audio_filepath"])
all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object(
speaker_timestamps,
uniq_id,
@@ -332,7 +312,6 @@ def convert_pred_mat_to_segments(
all_uems,
out_rttm_dir,
)
batch_pred_ts_segs.append(spk_ts)
return all_hypothesis, all_reference, all_uems
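
For readers unfamiliar with what the removed per-speaker loop did, here is a minimal sketch of the idea: threshold each speaker's per-frame sigmoid values into segments, shift them by the session offset, and clamp them to the session window. This is not the code of `ts_vad_post_processing` or `predlist_to_timestamps`. The names `frames_to_segments`, `shift_and_clamp`, and `frame_sec` are hypothetical, and the 0.08 s frame length assumes that `unit_10ms_frame_count=8` means eight 10 ms units per output frame.

```python
from typing import List, Tuple

Segment = Tuple[float, float]


def frames_to_segments(
    probs: List[float],
    onset: float = 0.5,
    offset: float = 0.5,
    frame_sec: float = 0.08,  # assumption: 8 x 10 ms per diarization frame
) -> List[Segment]:
    """Turn one speaker's per-frame probabilities into (start, end) segments in seconds."""
    segments: List[Segment] = []
    active = False
    start = 0.0
    for idx, prob in enumerate(probs):
        if not active and prob >= onset:
            # Speech begins once the probability reaches the onset threshold.
            active = True
            start = idx * frame_sec
        elif active and prob < offset:
            # Speech ends once the probability drops below the offset threshold.
            active = False
            segments.append((round(start, 2), round(idx * frame_sec, 2)))
    if active:
        # Close a segment that is still open at the end of the session.
        segments.append((round(start, 2), round(len(probs) * frame_sec, 2)))
    return segments


def shift_and_clamp(segments: List[Segment], offset_sec: float, duration_sec: float) -> List[Segment]:
    """Shift segments by the session offset and clamp them to [offset, offset + duration]."""
    low, high = offset_sec, offset_sec + duration_sec
    return [
        (round(min(max(seg_start + offset_sec, low), high), 2), round(min(max(seg_end + offset_sec, low), high), 2))
        for seg_start, seg_end in segments
    ]
```

Applying these two helpers once per speaker column gives the per-sample list of speaker timestamps that the refactored code now obtains from a single `predlist_to_timestamps` call.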


17 changes: 16 additions & 1 deletion nemo/collections/asr/data/audio_to_diar_label.py
@@ -1065,6 +1065,7 @@ def __init__(
round_digits: int = 2,
soft_targets: bool = False,
subsampling_factor: int = 8,
device: str = 'cpu',
):
super().__init__()
self.collection = EndtoEndDiarizationSpeechLabel(
@@ -1084,6 +1085,7 @@ def __init__(
self.soft_targets = soft_targets
self.round_digits = 2
self.floor_decimal = 10**self.round_digits
self.device = device

def __len__(self):
return len(self.collection)
@@ -1118,6 +1120,13 @@ def parse_rttm_for_targets_and_lens(self, rttm_file, offset, duration, target_le
Example of seg_target:
[[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]]
"""
if rttm_file in [None, '']:
# return torch.zeros((target_len, self.max_spks), dtype=self.dtype)
num_seg = torch.max(target_len)
targets = torch.zeros(num_seg, self.max_spks)
return targets
# return torch.zeros(target_len, self.max_spks)
# return None
with open(rttm_file, 'r') as f:
rttm_lines = f.readlines()

@@ -1232,11 +1241,15 @@ def __getitem__(self, index):
audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)]

audio_signal_length = torch.tensor(audio_signal.shape[0]).long()
audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu')
# audio_signal, audio_signal_length = audio_signal.to(self.device), audio_signal_length.to(self.device)
target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate)
# .to(
# self.device
# )
targets = self.parse_rttm_for_targets_and_lens(
rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len
)
# .to(self.device)
return audio_signal, audio_signal_length, targets, target_len
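
The new early return for a missing RTTM file makes inference-only manifests possible: the dataset yields an all-zero target matrix instead of failing on `open()`. A minimal sketch of that behavior in plain Python follows. The helper name `rttm_lines_to_targets`, the 0.08 s frame length, and the first-appearance speaker ordering are assumptions for illustration and are not taken from `parse_rttm_for_targets_and_lens`.

```python
from typing import Dict, List, Optional


def rttm_lines_to_targets(
    rttm_lines: Optional[List[str]],
    num_frames: int,
    max_spks: int = 4,
    frame_sec: float = 0.08,  # assumption: one target row per 80 ms frame
) -> List[List[float]]:
    """Build a (num_frames x max_spks) 0/1 target matrix from RTTM lines."""
    targets = [[0.0] * max_spks for _ in range(num_frames)]
    if not rttm_lines:
        # No reference labels (inference-only input): return all-zero targets.
        return targets

    speaker_index: Dict[str, int] = {}
    for line in rttm_lines:
        fields = line.split()
        # RTTM fields: type, file, channel, start, duration, <NA>, <NA>, speaker, ...
        if len(fields) < 8 or fields[0] != "SPEAKER":
            continue
        start, duration, name = float(fields[3]), float(fields[4]), fields[7]
        if name not in speaker_index:
            if len(speaker_index) >= max_spks:
                continue  # ignore speakers beyond the model's capacity
            speaker_index[name] = len(speaker_index)
        spk = speaker_index[name]
        first = max(0, int(round(start / frame_sec)))
        last = min(num_frames, int(round((start + duration) / frame_sec)))
        for frame in range(first, last):
            targets[frame][spk] = 1.0
    return targets
```

With no RTTM lines the function returns zeros of the requested shape, which matches the shape contract of the `torch.zeros(num_seg, self.max_spks)` fallback in the diff.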


@@ -1355,6 +1368,7 @@ def __init__(
window_stride,
global_rank: int,
soft_targets: bool,
device: str,
):
super().__init__(
manifest_filepath=manifest_filepath,
@@ -1365,6 +1379,7 @@ def __init__(
window_stride=window_stride,
global_rank=global_rank,
soft_targets=soft_targets,
device=device,
)

def eesd_train_collate_fn(self, batch):
115 changes: 107 additions & 8 deletions nemo/collections/asr/models/sortformer_diar_models.py
@@ -13,34 +13,58 @@
# limitations under the License.

import itertools
import os
import random
from collections import OrderedDict
from typing import Dict, List, Optional, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from tqdm import tqdm

from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset
from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset
from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy
from nemo.collections.asr.models.asr_model import ExportableEncDecModel
from nemo.collections.asr.parts.mixins.diarization import DiarizeConfig, SpkDiarizationMixin
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets
from nemo.collections.asr.parts.utils.vad_utils import predlist_to_timestamps
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType
from nemo.core.neural_types.elements import ProbsType
from nemo.utils import logging
from nemo.collections.asr.parts.mixins.diarization import SpkDiarizationMixin

__all__ = ['SortformerEncLabelModel']


class SortformerEncLabelModel(ModelPT, ExportableEncDecModel):
@dataclass
class PostProcessingParams:
"""
Postprocessing parameters for end-to-end speaker diarization models.
These parameters can significantly affect DER performance depending on the evaluation style and the dataset.
It is recommended to tune these parameters based on the evaluation style and the dataset
to achieve the desired DER performance.
"""

onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech
offset: float = 0.5 # Offset threshold for detecting the end of a speech
pad_onset: float = 0.0 # Adding durations before each speech segment
pad_offset: float = 0.0 # Adding durations after each speech segment
min_duration_on: float = 0.0 # Threshold for small non-speech deletion
min_duration_off: float = 0.0 # Threshold for short speech segment deletion


class SortformerEncLabelModel(ModelPT, ExportableEncDecModel, SpkDiarizationMixin):
"""
Encoder class for Sortformer diarization model.
Model class creates training, validation methods for setting up data performing model forward pass.
@@ -108,7 +132,6 @@
self.streaming_mode = self._cfg.get("streaming_mode", False)
self.save_hyperparameters("cfg")
self._init_eval_metrics()

speaker_inds = list(range(self._cfg.max_num_of_spks))
self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations

@@ -119,7 +142,6 @@
raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0")
self.pil_weight = pil_weight / (pil_weight + ats_weight)
self.ats_weight = ats_weight / (pil_weight + ats_weight)
logging.info(f"Normalized weights for PIL {self.pil_weight} and ATS {self.ats_weight}")
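
As a side note on the context lines above, the permutation table and the loss-weight normalization can be reproduced in a few lines of standard-library Python. The helper names `speaker_permutations` and `normalize_loss_weights` are hypothetical. The sketch only mirrors the arithmetic visible in this hunk.

```python
import itertools
from typing import List, Tuple


def speaker_permutations(max_num_of_spks: int) -> List[Tuple[int, ...]]:
    """Enumerate every ordering of the speaker indices (n! entries)."""
    return list(itertools.permutations(range(max_num_of_spks)))


def normalize_loss_weights(pil_weight: float, ats_weight: float) -> Tuple[float, float]:
    """Scale the PIL and ATS weights so that they sum to 1."""
    total = pil_weight + ats_weight
    if total == 0:
        raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0")
    return pil_weight / total, ats_weight / total
```

For a 4-speaker model this gives 4! = 24 permutations, and weights of 1.0 and 3.0 normalize to 0.25 and 0.75.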

def _init_eval_metrics(self):
"""
@@ -175,6 +197,7 @@
window_stride=self._cfg.preprocessor.window_stride,
global_rank=global_rank,
soft_targets=config.soft_targets if 'soft_targets' in config else False,
device=self.device,
)

self.data_collection = dataset.collection
@@ -268,6 +291,57 @@
preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq)
return preds

def _diarize_forward(self, batch: Any):
"""
A counterpart of `_transcribe_forward` function in ASR.
This function is a wrapper for forward pass functions for compatibility
with the existing classes.

Args:
batch (Any): The input batch containing audio signal and audio signal length.
diarcfg (DiarizeConfig): The configuration for diarization.

Returns:
preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels.
Shape: (batch_size, diar_frame_count, num_speakers)
"""
with torch.no_grad():
preds = self.forward(audio_signal=batch[0], audio_signal_length=batch[1])
return preds

def _setup_diarize_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
"""
Setup function for a temporary data loader which wraps the provided audio file.

Args:
config: A python dictionary which contains the following keys:
- manifest_filepath: Path to the manifest file containing audio file paths
and corresponding speaker labels.

Returns:
A pytorch DataLoader for the given audio file(s).
"""
if 'manifest_filepath' in config:
manifest_filepath = config['manifest_filepath']
batch_size = config['batch_size']
else:
manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
batch_size = min(config['batch_size'], len(config['paths2audio_files']))

dl_config = {
'manifest_filepath': manifest_filepath,
'sample_rate': self.preprocessor._sample_rate,
'num_spks': config.get('num_spks', self._cfg.max_num_of_spks),
'batch_size': batch_size,
'shuffle': False,
'soft_label_thres': 0.5,
'session_len_sec': config['session_len_sec'],
'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)),
'pin_memory': True,
}
temporary_datalayer = self.__setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer
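
A possible review point on `_setup_diarize_dataloader`: `os.cpu_count()` may return `None`, and `os.cpu_count() - 1` is 0 on a single-core machine, so the `num_workers` default is worth guarding. The sketch below restates the config-building logic as a free function with that guard added. The name `build_diarize_dl_config` and its `sample_rate` and `max_num_of_spks` parameters are hypothetical stand-ins for the model attributes used in the diff.

```python
import os
from typing import Any, Dict


def build_diarize_dl_config(config: Dict[str, Any], sample_rate: int, max_num_of_spks: int) -> Dict[str, Any]:
    """Assemble the temporary dataloader config from the diarize() config dict."""
    if 'manifest_filepath' in config:
        manifest_filepath = config['manifest_filepath']
        batch_size = config['batch_size']
    else:
        # Audio paths were given directly; a manifest is expected in the temp dir.
        manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
        batch_size = min(config['batch_size'], len(config['paths2audio_files']))

    # Guard added in this sketch: cpu_count() can be None, and the result must not be negative.
    cpu_count = os.cpu_count() or 1
    default_workers = max(0, min(batch_size, cpu_count - 1))

    return {
        'manifest_filepath': manifest_filepath,
        'sample_rate': sample_rate,
        'num_spks': config.get('num_spks', max_num_of_spks),
        'batch_size': batch_size,
        'shuffle': False,
        'soft_label_thres': 0.5,
        'session_len_sec': config['session_len_sec'],
        'num_workers': config.get('num_workers', default_workers),
        'pin_memory': True,
    }
```

Keeping this logic in a pure function would also make it easy to unit test without instantiating the model.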

def process_signal(self, audio_signal, audio_signal_length):
"""
Extract audio features from time-series signal for further processing in the model.
@@ -290,7 +364,7 @@
- processed_signal_length (torch.Tensor): The length of each processed signal.
Shape: (batch_size,)
"""
audio_signal = audio_signal.to(self.device)
audio_signal, audio_signal_length = audio_signal.to(self.device), audio_signal_length.to(self.device)
audio_signal = (1 / (audio_signal.max() + self.eps)) * audio_signal
processed_signal, processed_signal_length = self.preprocessor(
input_signal=audio_signal, length=audio_signal_length
@@ -572,8 +646,33 @@

def diarize(
self,
audio: Union[str, List[str], np.ndarray, DataLoader],
batch_size: int = 1,
verbose: bool = False,
num_workers: int = 0,
bypass_postprocessing: bool = True,
postprocessing_config=None,
unit_10ms_frame_count: int = 8,
):
"""One-clieck runner function for diarization."""
"""One-click runner function for diarization."""
# TODO: A direct one-click runner function that generates
# speaker labels from audio file path lists.
raise NotImplementedError

batch_preds_list = super().diarize(
audio=audio,
batch_size=batch_size,
verbose=verbose,
num_workers=num_workers,
)

pp_params = OmegaConf.structured(PostProcessingParams())
postprocessing_config = pp_params if postprocessing_config is None else postprocessing_config
total_speaker_timestamps = predlist_to_timestamps(
batch_preds_list=batch_preds_list,
audio_rttm_map_dict=self._diarize_audio_rttm_map,
cfg_vad_params=postprocessing_config,
unit_10ms_frame_count=int(self._cfg.encoder.subsampling_factor),
bypass_postprocessing=bypass_postprocessing,
precision=2,
)
return total_speaker_timestamps
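
The default-resolution step in `diarize()` (use `PostProcessingParams()` unless the caller passes a config) can be illustrated with the standard library alone. The dataclass below copies the fields shown in this diff. `resolve_postprocessing` is a hypothetical helper, and the sketch leaves out the `OmegaConf.structured` wrapping that the PR applies.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class PostProcessingParams:
    """Field names and defaults copied from the dataclass in this PR."""

    onset: float = 0.5
    offset: float = 0.5
    pad_onset: float = 0.0
    pad_offset: float = 0.0
    min_duration_on: float = 0.0
    min_duration_off: float = 0.0


def resolve_postprocessing(cfg: Optional[PostProcessingParams] = None) -> PostProcessingParams:
    """Fall back to the default parameters when the caller supplies none."""
    return PostProcessingParams() if cfg is None else cfg


# A caller who wants tuned values overrides only the fields that differ.
tuned = replace(PostProcessingParams(), onset=0.6, min_duration_off=0.1)
```

Since `bypass_postprocessing` defaults to `True` in the new signature, a caller who passes a tuned config would presumably also need to set `bypass_postprocessing=False` for it to take effect. That reading is an assumption based on the parameter name.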