-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sortformer Diarizer 4spk v1 model PR Part 2: Unit-tests for Sortformer Diarizer. #11336
Open
tango4j
wants to merge
131
commits into
NVIDIA:main
Choose a base branch
from
tango4j:sortformer/pr_02
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
131 commits
Select commit
Hold shift + click to select a range
e69ec8e
Adding the first pr files models and dataset
tango4j 2914325
Tested all unit-test files
tango4j 9a468ac
Name changes on yaml files and train example
tango4j a910d30
Merge branch 'main' into sortformer/pr_01
tango4j 2f44fe1
Apply isort and black reformatting
tango4j 4ddc59b
Reflecting comments and removing unnecessary parts for this PR
tango4j 43d95f0
Resolved conflicts
tango4j 40e9f95
Apply isort and black reformatting
tango4j f7f84bb
Adding docstrings to reflect the PR comments
tango4j 95acd79
Resolved the new conflict
tango4j 919f4da
Merge branch 'main' into sortformer/pr_01
tango4j 4134e25
removed the unused find_first_nonzero
tango4j d3432e5
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j 5dd4d4c
Apply isort and black reformatting
tango4j ca5eea3
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j 9d493c0
Merge branch 'main' into sortformer/pr_01
tango4j 037f61e
Fixed all pylint issues
tango4j a8bc048
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j cb23268
Apply isort and black reformatting
tango4j 4a266b9
Resolving pylint issues
tango4j 5e4e9c8
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j c31c60c
Merge branch 'main' into sortformer/pr_01
tango4j 6e2225e
Apply isort and black reformatting
tango4j ab93b17
Removing unused varialbe in audio_to_diar_label.py
tango4j 4f3ee66
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j 3f24b82
Merge branch 'main' into sortformer/pr_01
tango4j f49e107
Merge branch 'main' into sortformer/pr_01
tango4j 7dea01b
Fixed docstrings in training script
tango4j 2a99d53
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j 71d515f
Line-too-long issue from Pylint fixed
tango4j 9b7b93e
Merge branch 'main' into sortformer/pr_01
tango4j f2d5e36
Adding get_subsegments_scriptable to prevent jit.script error
tango4j 9cca3e8
Apply isort and black reformatting
tango4j 731caa8
Merge branch 'main' into sortformer/pr_01
tango4j 681fe38
Merge branch 'main' into sortformer/pr_01
tango4j 008dcbd
Addressed Code-QL issues
tango4j d89ed91
Addressed Code-QL issues and resolved conflicts
tango4j 045f3a2
Resolved conflicts on bce_loss.py
tango4j 1dcf9ab
Apply isort and black reformatting
tango4j be8ac22
Adding all the diarization reltated unit-tests
tango4j ca44a66
Moving speaker task related unit test files to speaker_tasks folder
tango4j 1360831
Fixed uninit variable issue in bce_loss.py spotted by codeQL
tango4j 553197a
Apply isort and black reformatting
tango4j 7893e75
Merge branch 'main' into sortformer/pr_01
tango4j f7fced9
Merge branch 'main' into sortformer/pr_01
tango4j 87af813
Merge branch 'main' into sortformer/pr_02
tango4j 734dfd8
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j c3c0b32
Fixing code-QL issues
tango4j 631555d
Apply isort and black reformatting
tango4j 99ee5cc
Merge branch 'main' into sortformer/pr_02
tango4j 9371ed0
Merge branch 'main' into sortformer/pr_01
tango4j 6a3bb62
Reflecting PR comments from weiqingw
tango4j 4e0327c
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j b8a49ea
Apply isort and black reformatting
tango4j 6198a20
Line too long pylint issue resolved in e2e_diarize_speech.py
tango4j e4b0154
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j 07c4242
Apply isort and black reformatting
tango4j 9feb013
Resovled unused variable issue in model test
tango4j 7496a0d
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j db90424
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j 0eeaf06
Merge branch 'main' into sortformer/pr_01
tango4j fa11155
Reflecting the comment on Nov 21st 2024.
tango4j b5878cc
Apply isort and black reformatting
tango4j bfe36e7
Merge branch 'main' into sortformer/pr_01
tango4j 7898697
Unused variable import time
tango4j e167dba
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j 8712278
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j 1bb89d5
Merge branch 'main' into sortformer/pr_01
tango4j e4006cf
Adding docstrings to score_labels() function in der.py
tango4j a92e4e6
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j ca480eb
Apply isort and black reformatting
tango4j 5ea9d7d
Merge branch 'main' into sortformer/pr_01
tango4j 1b091c8
Merge branch 'main' into sortformer/pr_01
tango4j af04832
Reflecting comments on YAML files and model file variable changes.
tango4j a4367a3
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j edbe159
Apply isort and black reformatting
tango4j b47579b
Merge branch 'main' into sortformer/pr_01
tango4j 8365a05
Added get_subsegments_scriptable for legacy get_subsegment functions
tango4j f2250a0
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j 5275fb5
Merge branch 'main' into sortformer/pr_01
tango4j 86315db
Apply isort and black reformatting
tango4j 07f791a
Resolved line too long pylint issues
tango4j 2b23136
Resolved line too long pylint issues and merged main
tango4j 30f1159
Apply isort and black reformatting
tango4j f9a9884
Merge branch 'main' into sortformer/pr_01
tango4j 0e50abf
Merge branch 'main' into sortformer/pr_01
tango4j f232a40
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j 6fd3076
Merge branch 'main' into sortformer/pr_01
tango4j 7ec3b1f
Added training and inference CI-tests
tango4j 0eb260e
Added the missing parse_func in preprocessing/collections.py
tango4j 895b4ed
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j 0d6ebc7
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j 37d4240
Adding the missing parse_func in preprocessing/collections.py
tango4j 01085ab
Merge branch 'main' into sortformer/pr_01
tango4j 03c425b
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j bde6887
Fixed an indentation error
tango4j 3f378f6
Merge branch 'main' into sortformer/pr_02
tango4j 470579d
Merge branch 'main' into sortformer/pr_01
tango4j 024a391
Merge branch 'main' into sortformer/pr_02
tango4j 73944e3
Resolved multi_bin_acc and bce_loss issues
tango4j ea4c2a7
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j 586c64c
Merge branch 'main' into sortformer/pr_01
tango4j 567f927
Resolved line-too-long for msdd_models.py
tango4j 4c4eb1e
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j b503d3e
Apply isort and black reformatting
tango4j fa2a663
Merge branch 'main' into sortformer/pr_01
tango4j 81b751e
Merge branch 'main' into sortformer/pr_02
tango4j f82d9c7
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j f7029d7
Merge branch 'main' into sortformer/pr_02
tango4j 2c6eed7
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j f469e72
Merging main and resolving conflicts
tango4j f5a9c47
Code QL issues and fixed test errors
tango4j ee258f7
Apply isort and black reformatting
tango4j 3781604
line too long in audio_to_diar_label.py
tango4j 4c19278
line too long in audio_to_diar_label.py
tango4j 64eaefd
Apply isort and black reformatting
tango4j 38e52b5
resolving CICD test issues
tango4j 10062c7
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j 4988c4d
Merge branch 'main' into sortformer/pr_02
tango4j 380e3e9
Merge branch 'NVIDIA:main' into sortformer/pr_02
tango4j e1db4a5
Merge branch 'main' into sortformer/pr_02
tango4j dd6a097
Fixing codeQL issues
tango4j 80f15cb
Merge branch 'main' into sortformer/pr_02
tango4j 2d05f6a
Merge branch 'main' into sortformer/pr_02
tango4j e13d5fc
Fixed pin memory False for inference
tango4j 4dfaa82
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j 85c0d9f
Merge branch 'main' into sortformer/pr_02
tango4j 86fb5a0
Resolved the device mismatch in get_ats_targets
tango4j e753a3b
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j 23164d6
Apply isort and black reformatting
tango4j c71a713
Merge branch 'main' into sortformer/pr_02
tango4j File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import os | ||
import tempfile | ||
|
||
import pytest | ||
import torch.cuda | ||
|
||
from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset | ||
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer | ||
from nemo.collections.asr.parts.utils.speaker_utils import get_vad_out_from_rttm_line, read_rttm_lines | ||
|
||
|
||
def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec): | ||
""" | ||
Check if the maximum RTTM duration exceeds the length of the provided audio file. | ||
|
||
Args: | ||
rttm_file_path (str): Path to the RTTM file. | ||
wav_len_in_sec (float): Length of the audio file in seconds. | ||
|
||
Returns: | ||
bool: True if the maximum RTTM duration is less than or equal to the length of the audio file, False otherwise. | ||
""" | ||
rttm_lines = read_rttm_lines(rttm_file_path) | ||
max_rttm_sec = 0 | ||
for line in rttm_lines: | ||
start, dur = get_vad_out_from_rttm_line(line) | ||
max_rttm_sec = max(max_rttm_sec, start + dur) | ||
return max_rttm_sec <= wav_len_in_sec | ||
|
||
|
||
class TestAudioToSpeechE2ESpkDiarDataset: | ||
|
||
@pytest.mark.unit | ||
def test_e2e_speaker_diar_dataset(self, test_data_dir): | ||
|
||
manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/diarizer/lsm_val.json')) | ||
|
||
batch_size = 4 | ||
num_samples = 8 | ||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
data_dict_list = [] | ||
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: | ||
with open(manifest_path, 'r', encoding='utf-8') as mfile: | ||
for ix, line in enumerate(mfile): | ||
if ix >= num_samples: | ||
break | ||
|
||
line = line.replace("tests/data/", test_data_dir + "/").replace("\n", "") | ||
f.write(f"{line}\n") | ||
data_dict = json.loads(line) | ||
data_dict_list.append(data_dict) | ||
|
||
f.seek(0) | ||
featurizer = WaveformFeaturizer(sample_rate=16000, int_values=False, augmentor=None) | ||
|
||
dataset = AudioToSpeechE2ESpkDiarDataset( | ||
manifest_filepath=f.name, | ||
soft_label_thres=0.5, | ||
session_len_sec=90, | ||
num_spks=4, | ||
featurizer=featurizer, | ||
window_stride=0.01, | ||
global_rank=0, | ||
soft_targets=False, | ||
device=device, | ||
) | ||
dataloader_instance = torch.utils.data.DataLoader( | ||
dataset=dataset, | ||
batch_size=batch_size, | ||
collate_fn=dataset.eesd_train_collate_fn, | ||
drop_last=False, | ||
shuffle=False, | ||
num_workers=1, | ||
pin_memory=False, | ||
) | ||
assert len(dataloader_instance) == (num_samples / batch_size) # Check if the number of batches is correct | ||
batch_counts = len(dataloader_instance) | ||
|
||
deviation_thres_rate = 0.01 # 1% deviation allowed | ||
for batch_index, batch in enumerate(dataloader_instance): | ||
if batch_index != batch_counts - 1: | ||
assert len(batch) == batch_size, "Batch size does not match the expected value" | ||
audio_signals, audio_signal_len, targets, target_lens = batch | ||
for sample_index in range(audio_signals.shape[0]): | ||
dataloader_audio_in_sec = audio_signal_len[sample_index].item() | ||
data_dur_in_sec = abs( | ||
data_dict_list[batch_size * batch_index + sample_index]['duration'] * featurizer.sample_rate | ||
- dataloader_audio_in_sec | ||
) | ||
assert ( | ||
data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec | ||
), "Duration deviation exceeds 1%" | ||
assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values" | ||
assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values" | ||
assert not torch.isnan(targets).any(), "targets tensor contains NaN values" | ||
assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a bit difficult to understand when the lists (e.g.,
self.preds_total_list
,self.batch_f1_accs_list
) are initialized in this function but are updated inside theself. _get_aux_test_batch_evaluations
, but not a strong opinion~