Skip to content

Commit

Permalink
Sortformer Diarizer 4spk v1 model PR Part 3: Speaker Diarization Mixin (
Browse files Browse the repository at this point in the history
#11511)

* Adding diarization mixin for one click inference

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Resolving CodeQL and Pylint

Signed-off-by: taejinp <[email protected]>

* Resolving CodeQL and Pylint - unsaved files resolved

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Unused package manifest_utils

Signed-off-by: taejinp <[email protected]>

* Resolved diarization mixin test issues

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Removed commented lines

Signed-off-by: taejinp <[email protected]>

* updating mixins code

Signed-off-by: ipmedenn <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ipmedenn <[email protected]>

* fixing test_diarizartion.py

Signed-off-by: ipmedenn <[email protected]>

* moving diarization postprocessing-related stuff to vad_utils.py

Signed-off-by: ipmedenn <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ipmedenn <[email protected]>

* Resolving PyLint

Signed-off-by: ipmedenn <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ipmedenn <[email protected]>

* fixing batch_idx issue in sortformer_diar_models.py

Signed-off-by: ipmedenn <[email protected]>

* adding sync_dist=True for sortformer validation

Signed-off-by: ipmedenn <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Reflecting the comments from PR

Signed-off-by: taejinp <[email protected]>

* Reflecting the comments from PR 2nd

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

* Resolved a codeQL unused variable

Signed-off-by: taejinp <[email protected]>

* Now moved existance check after

Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tango4j <[email protected]>

---------

Signed-off-by: taejinp <[email protected]>
Signed-off-by: tango4j <[email protected]>
Signed-off-by: ipmedenn <[email protected]>
Signed-off-by: ipmedenn <[email protected]>
Co-authored-by: tango4j <[email protected]>
Co-authored-by: ipmedenn <[email protected]>
Co-authored-by: ipmedenn <[email protected]>
  • Loading branch information
4 people authored and yashaswikarnati committed Dec 20, 2024
1 parent c44f092 commit be5d440
Show file tree
Hide file tree
Showing 7 changed files with 1,108 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
Usage for diarization inference:
The end-to-end speaker diarization model can be specified by either "model_path" or "pretrained_name".
The end-to-end speaker diarization model can be specified by "model_path".
Data for diarization is fed through the "dataset_manifest".
By default, post-processing is bypassed, and only binarization is performed.
If you want to reproduce DER scores reported on NeMo model cards, you need to apply post-processing steps.
Expand All @@ -45,45 +45,32 @@
import lightning.pytorch as pl
import optuna
import torch
import yaml
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from tqdm import tqdm

from nemo.collections.asr.metrics.der import score_labels
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object
from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing
from nemo.collections.asr.parts.utils.speaker_utils import (
audio_rttm_map,
get_uniqname_from_filepath,
timestamps_to_pyannote_object,
)
from nemo.collections.asr.parts.utils.vad_utils import (
PostProcessingParams,
load_postprocessing_from_yaml,
predlist_to_timestamps,
)
from nemo.core.config import hydra_runner

seed_everything(42)
torch.backends.cudnn.deterministic = True


@dataclass
class PostProcessingParams:
"""
Postprocessing parameters for end-to-end speaker diarization models.
These parameters can significantly affect DER performance depending on the evaluation style and the dataset.
It is recommended to tune these parameters based on the evaluation style and the dataset
to achieve the desired DER performance.
"""

onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech
offset: float = 0.5 # Offset threshold for detecting the end of a speech
pad_onset: float = 0.0 # Adding durations before each speech segment
pad_offset: float = 0.0 # Adding durations after each speech segment
min_duration_on: float = 0.0 # Threshold for small non-speech deletion
min_duration_off: float = 0.0 # Threshold for short speech segment deletion


@dataclass
class DiarizationConfig:
"""Diarization configuration parameters for inference."""

model_path: Optional[str] = None # Path to a .nemo file
pretrained_name: Optional[str] = None # Name of a pretrained model
audio_dir: Optional[str] = None # Path to a directory which contains audio files
dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest

postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations
Expand Down Expand Up @@ -114,36 +101,6 @@ class DiarizationConfig:
optuna_n_trials: int = 100000


def load_postprocessing_from_yaml(postprocessing_yaml: PostProcessingParams = None) -> PostProcessingParams:
"""
Load postprocessing parameters from a YAML file.
Args:
postprocessing_yaml (str):
Path to a YAML file for postprocessing configurations.
Returns:
postprocessing_params (dataclass):
Postprocessing parameters loaded from the YAML file.
"""
# Add PostProcessingParams as a field
postprocessing_params = OmegaConf.structured(PostProcessingParams())
if postprocessing_yaml is None:
logging.info(
f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied."
)
else:
# Load postprocessing params from the provided YAML file
with open(postprocessing_yaml, 'r') as file:
yaml_params = yaml.safe_load(file)['parameters']
# Update the postprocessing_params with the loaded values
logging.info(f"Postprocessing YAML file '{postprocessing_yaml}' has been loaded.")
for key, value in yaml_params.items():
if hasattr(postprocessing_params, key):
setattr(postprocessing_params, key, value)
return postprocessing_params


def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams:
"""
Suggests hyperparameters for postprocessing using Optuna.
Expand Down Expand Up @@ -303,26 +260,19 @@ def convert_pred_mat_to_segments(
"""
batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], []
cfg_vad_params = OmegaConf.structured(postprocessing_cfg)
pp_message = "Bypass PP, Running Binarization" if bypass_postprocessing else "Running post-processing"
for sample_idx, (uniq_id, audio_rttm_values) in tqdm(
enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc=pp_message
):
spk_ts = []
offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration']
speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0)
speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])]
for spk_id in range(speaker_assign_mat.shape[-1]):
ts_mat = ts_vad_post_processing(
speaker_assign_mat[:, spk_id],
cfg_vad_params=cfg_vad_params,
unit_10ms_frame_count=unit_10ms_frame_count,
bypass_postprocessing=bypass_postprocessing,
)
ts_mat = ts_mat + offset
ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration))
ts_seg_list = ts_mat.tolist()
speaker_timestamps[spk_id].extend(ts_seg_list)
spk_ts.append(ts_seg_list)
total_speaker_timestamps = predlist_to_timestamps(
batch_preds_list=batch_preds_list,
audio_rttm_map_dict=audio_rttm_map_dict,
cfg_vad_params=cfg_vad_params,
unit_10ms_frame_count=unit_10ms_frame_count,
bypass_postprocessing=bypass_postprocessing,
)
for sample_idx, (uniq_id, audio_rttm_values) in enumerate(audio_rttm_map_dict.items()):
speaker_timestamps = total_speaker_timestamps[sample_idx]
if audio_rttm_values.get("uniq_id", None) is not None:
uniq_id = audio_rttm_values["uniq_id"]
else:
uniq_id = get_uniqname_from_filepath(audio_rttm_values["audio_filepath"])
all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object(
speaker_timestamps,
uniq_id,
Expand All @@ -332,7 +282,6 @@ def convert_pred_mat_to_segments(
all_uems,
out_rttm_dir,
)
batch_pred_ts_segs.append(spk_ts)
return all_hypothesis, all_reference, all_uems


Expand All @@ -348,10 +297,8 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]:
if cfg.random_seed:
pl.seed_everything(cfg.random_seed)

if cfg.model_path is None and cfg.pretrained_name is None:
raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
if cfg.audio_dir is None and cfg.dataset_manifest is None:
raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")
if cfg.model_path is None:
raise ValueError("cfg.model_path cannot be None. Please specify the path to the model.")

# setup GPU
torch.set_float32_matmul_precision(cfg.matmul_precision)
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/asr/data/audio_to_diar_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,11 @@ def parse_rttm_for_targets_and_lens(self, rttm_file, offset, duration, target_le
Example of seg_target:
[[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]]
"""
if rttm_file in [None, '']:
num_seg = torch.max(target_len)
targets = torch.zeros(num_seg, self.max_spks)
return targets

with open(rttm_file, 'r') as f:
rttm_lines = f.readlines()

Expand Down
Loading

0 comments on commit be5d440

Please sign in to comment.