Sortformer Diarizer 4spk v1 model PR Part 2: Unit-tests for Sortformer Diarizer. #11336

Status: Open. Wants to merge 131 commits into base: main.

Changes from all commits (131):
e69ec8e
Adding the first pr files models and dataset
tango4j Nov 14, 2024
2914325
Tested all unit-test files
tango4j Nov 14, 2024
9a468ac
Name changes on yaml files and train example
tango4j Nov 14, 2024
a910d30
Merge branch 'main' into sortformer/pr_01
tango4j Nov 14, 2024
2f44fe1
Apply isort and black reformatting
tango4j Nov 14, 2024
4ddc59b
Reflecting comments and removing unnecessary parts for this PR
tango4j Nov 15, 2024
43d95f0
Resolved conflicts
tango4j Nov 15, 2024
40e9f95
Apply isort and black reformatting
tango4j Nov 15, 2024
f7f84bb
Adding docstrings to reflect the PR comments
tango4j Nov 15, 2024
95acd79
Resolved the new conflict
tango4j Nov 15, 2024
919f4da
Merge branch 'main' into sortformer/pr_01
tango4j Nov 15, 2024
4134e25
removed the unused find_first_nonzero
tango4j Nov 15, 2024
d3432e5
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
5dd4d4c
Apply isort and black reformatting
tango4j Nov 15, 2024
ca5eea3
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
9d493c0
Merge branch 'main' into sortformer/pr_01
tango4j Nov 15, 2024
037f61e
Fixed all pylint issues
tango4j Nov 15, 2024
a8bc048
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
cb23268
Apply isort and black reformatting
tango4j Nov 15, 2024
4a266b9
Resolving pylint issues
tango4j Nov 15, 2024
5e4e9c8
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
c31c60c
Merge branch 'main' into sortformer/pr_01
tango4j Nov 15, 2024
6e2225e
Apply isort and black reformatting
tango4j Nov 15, 2024
ab93b17
Removing unused variable in audio_to_diar_label.py
tango4j Nov 15, 2024
4f3ee66
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 15, 2024
3f24b82
Merge branch 'main' into sortformer/pr_01
tango4j Nov 16, 2024
f49e107
Merge branch 'main' into sortformer/pr_01
tango4j Nov 16, 2024
7dea01b
Fixed docstrings in training script
tango4j Nov 16, 2024
2a99d53
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 16, 2024
71d515f
Line-too-long issue from Pylint fixed
tango4j Nov 16, 2024
9b7b93e
Merge branch 'main' into sortformer/pr_01
tango4j Nov 18, 2024
f2d5e36
Adding get_subsegments_scriptable to prevent jit.script error
tango4j Nov 19, 2024
9cca3e8
Apply isort and black reformatting
tango4j Nov 19, 2024
731caa8
Merge branch 'main' into sortformer/pr_01
tango4j Nov 19, 2024
681fe38
Merge branch 'main' into sortformer/pr_01
tango4j Nov 19, 2024
008dcbd
Addressed Code-QL issues
tango4j Nov 19, 2024
d89ed91
Addressed Code-QL issues and resolved conflicts
tango4j Nov 19, 2024
045f3a2
Resolved conflicts on bce_loss.py
tango4j Nov 19, 2024
1dcf9ab
Apply isort and black reformatting
tango4j Nov 19, 2024
be8ac22
Adding all the diarization related unit-tests
tango4j Nov 19, 2024
ca44a66
Moving speaker task related unit test files to speaker_tasks folder
tango4j Nov 20, 2024
1360831
Fixed uninit variable issue in bce_loss.py spotted by codeQL
tango4j Nov 20, 2024
553197a
Apply isort and black reformatting
tango4j Nov 20, 2024
7893e75
Merge branch 'main' into sortformer/pr_01
tango4j Nov 20, 2024
f7fced9
Merge branch 'main' into sortformer/pr_01
tango4j Nov 20, 2024
87af813
Merge branch 'main' into sortformer/pr_02
tango4j Nov 20, 2024
734dfd8
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 20, 2024
c3c0b32
Fixing code-QL issues
tango4j Nov 21, 2024
631555d
Apply isort and black reformatting
tango4j Nov 21, 2024
99ee5cc
Merge branch 'main' into sortformer/pr_02
tango4j Nov 21, 2024
9371ed0
Merge branch 'main' into sortformer/pr_01
tango4j Nov 21, 2024
6a3bb62
Reflecting PR comments from weiqingw
tango4j Nov 21, 2024
4e0327c
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 21, 2024
b8a49ea
Apply isort and black reformatting
tango4j Nov 21, 2024
6198a20
Line too long pylint issue resolved in e2e_diarize_speech.py
tango4j Nov 21, 2024
e4b0154
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 21, 2024
07c4242
Apply isort and black reformatting
tango4j Nov 21, 2024
9feb013
Resolved unused variable issue in model test
tango4j Nov 21, 2024
7496a0d
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j Nov 21, 2024
db90424
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 21, 2024
0eeaf06
Merge branch 'main' into sortformer/pr_01
tango4j Nov 21, 2024
fa11155
Reflecting the comment on Nov 21st 2024.
tango4j Nov 21, 2024
b5878cc
Apply isort and black reformatting
tango4j Nov 21, 2024
bfe36e7
Merge branch 'main' into sortformer/pr_01
tango4j Nov 21, 2024
7898697
Unused variable import time
tango4j Nov 21, 2024
e167dba
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 21, 2024
8712278
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 21, 2024
1bb89d5
Merge branch 'main' into sortformer/pr_01
tango4j Nov 21, 2024
e4006cf
Adding docstrings to score_labels() function in der.py
tango4j Nov 22, 2024
a92e4e6
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 22, 2024
ca480eb
Apply isort and black reformatting
tango4j Nov 22, 2024
5ea9d7d
Merge branch 'main' into sortformer/pr_01
tango4j Nov 22, 2024
1b091c8
Merge branch 'main' into sortformer/pr_01
tango4j Nov 22, 2024
af04832
Reflecting comments on YAML files and model file variable changes.
tango4j Nov 22, 2024
a4367a3
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 22, 2024
edbe159
Apply isort and black reformatting
tango4j Nov 22, 2024
b47579b
Merge branch 'main' into sortformer/pr_01
tango4j Nov 22, 2024
8365a05
Added get_subsegments_scriptable for legacy get_subsegment functions
tango4j Nov 22, 2024
f2250a0
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 22, 2024
5275fb5
Merge branch 'main' into sortformer/pr_01
tango4j Nov 22, 2024
86315db
Apply isort and black reformatting
tango4j Nov 22, 2024
07f791a
Resolved line too long pylint issues
tango4j Nov 22, 2024
2b23136
Resolved line too long pylint issues and merged main
tango4j Nov 22, 2024
30f1159
Apply isort and black reformatting
tango4j Nov 22, 2024
f9a9884
Merge branch 'main' into sortformer/pr_01
tango4j Nov 23, 2024
0e50abf
Merge branch 'main' into sortformer/pr_01
tango4j Nov 24, 2024
f232a40
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 25, 2024
6fd3076
Merge branch 'main' into sortformer/pr_01
tango4j Nov 26, 2024
7ec3b1f
Added training and inference CI-tests
tango4j Nov 26, 2024
0eb260e
Added the missing parse_func in preprocessing/collections.py
tango4j Nov 26, 2024
895b4ed
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 26, 2024
0d6ebc7
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 26, 2024
37d4240
Adding the missing parse_func in preprocessing/collections.py
tango4j Nov 26, 2024
01085ab
Merge branch 'main' into sortformer/pr_01
tango4j Nov 26, 2024
03c425b
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 26, 2024
bde6887
Fixed an indentation error
tango4j Nov 26, 2024
3f378f6
Merge branch 'main' into sortformer/pr_02
tango4j Nov 26, 2024
470579d
Merge branch 'main' into sortformer/pr_01
tango4j Nov 26, 2024
024a391
Merge branch 'main' into sortformer/pr_02
tango4j Nov 26, 2024
73944e3
Resolved multi_bin_acc and bce_loss issues
tango4j Nov 27, 2024
ea4c2a7
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 27, 2024
586c64c
Merge branch 'main' into sortformer/pr_01
tango4j Nov 27, 2024
567f927
Resolved line-too-long for msdd_models.py
tango4j Nov 27, 2024
4c4eb1e
Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in…
tango4j Nov 27, 2024
b503d3e
Apply isort and black reformatting
tango4j Nov 27, 2024
fa2a663
Merge branch 'main' into sortformer/pr_01
tango4j Nov 27, 2024
81b751e
Merge branch 'main' into sortformer/pr_02
tango4j Nov 27, 2024
f82d9c7
Merge remote-tracking branch 'origin/sortformer/pr_01' into sortforme…
tango4j Nov 27, 2024
f7029d7
Merge branch 'main' into sortformer/pr_02
tango4j Nov 27, 2024
2c6eed7
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j Nov 29, 2024
f469e72
Merging main and resolving conflicts
tango4j Nov 29, 2024
f5a9c47
Code QL issues and fixed test errors
tango4j Nov 29, 2024
ee258f7
Apply isort and black reformatting
tango4j Nov 29, 2024
3781604
line too long in audio_to_diar_label.py
tango4j Nov 29, 2024
4c19278
line too long in audio_to_diar_label.py
tango4j Nov 29, 2024
64eaefd
Apply isort and black reformatting
tango4j Nov 29, 2024
38e52b5
resolving CICD test issues
tango4j Dec 2, 2024
10062c7
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j Dec 2, 2024
4988c4d
Merge branch 'main' into sortformer/pr_02
tango4j Dec 2, 2024
380e3e9
Merge branch 'NVIDIA:main' into sortformer/pr_02
tango4j Dec 2, 2024
e1db4a5
Merge branch 'main' into sortformer/pr_02
tango4j Dec 3, 2024
dd6a097
Fixing codeQL issues
tango4j Dec 5, 2024
80f15cb
Merge branch 'main' into sortformer/pr_02
tango4j Dec 5, 2024
2d05f6a
Merge branch 'main' into sortformer/pr_02
tango4j Dec 6, 2024
e13d5fc
Fixed pin memory False for inference
tango4j Dec 6, 2024
4dfaa82
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j Dec 6, 2024
85c0d9f
Merge branch 'main' into sortformer/pr_02
tango4j Dec 6, 2024
86fb5a0
Resolved the device mismatch in get_ats_targets
tango4j Dec 8, 2024
e753a3b
Merge branch 'sortformer/pr_02' of https://github.com/tango4j/NeMo in…
tango4j Dec 8, 2024
23164d6
Apply isort and black reformatting
tango4j Dec 8, 2024
c71a713
Merge branch 'main' into sortformer/pr_02
tango4j Dec 10, 2024
29 changes: 29 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -816,6 +816,33 @@ jobs:
+trainer.fast_dev_run=True \
exp_manager.exp_dir=/tmp/speaker_diarization_results

L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
python examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py \
trainer.devices="[0]" \
batch_size=3 \
model.train_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_train/eesd_train_tiny.json \
model.validation_ds.manifest_filepath=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \
exp_manager.exp_dir=/tmp/speaker_diarization_results \
+trainer.fast_dev_run=True

L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py \
model_path=/home/TestData/an4_diarizer/diar_sortformer_4spk-v1-tiny.nemo \
dataset_manifest=/home/TestData/an4_diarizer/simulated_valid/eesd_valid_tiny.json \
batch_size=1

L2_Speaker_dev_run_Speech_to_Label:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
@@ -4586,6 +4613,8 @@ jobs:
- L2_Speech_to_Text_EMA
- L2_Speaker_dev_run_Speaker_Recognition
- L2_Speaker_dev_run_Speaker_Diarization
- L2_Speaker_dev_run_EndtoEnd_Speaker_Diarization_Sortformer
- L2_Speaker_dev_run_EndtoEnd_Diarizer_Inference
- L2_Speaker_dev_run_Speech_to_Label
- L2_Speaker_dev_run_Speaker_Diarization_with_ASR_Inference
- L2_Speaker_dev_run_Clustering_Diarizer_Inference

@@ -386,6 +386,7 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]:
diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest
infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest)
diar_model._cfg.test_ds.batch_size = cfg.batch_size
diar_model._cfg.test_ds.pin_memory = False

# Model setup for inference
diar_model._cfg.test_ds.num_workers = cfg.num_workers
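
A side note on the pin_memory change above: PyTorch pinning is meant for dense CPU tensors, so once the dataset is allowed to place samples directly on a CUDA device (via the new device argument introduced below), pinning has to stay off. The following is a minimal, self-contained sketch of that interaction; ToyCudaDataset is a hypothetical stand-in, not the PR's AudioToSpeechE2ESpkDiarDataset.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyCudaDataset(Dataset):
    # Hypothetical stand-in for a dataset that already places samples on a target device.
    def __init__(self, device: str = "cpu"):
        self.device = device

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # Each sample is returned on self.device, mirroring the PR's __getitem__ change.
        return torch.randn(16000, device=self.device)


device = "cuda" if torch.cuda.is_available() else "cpu"
loader = DataLoader(
    ToyCudaDataset(device=device),
    batch_size=2,
    num_workers=0,     # keep workers at 0 here, since samples are created on the device directly
    pin_memory=False,  # pinning applies to CPU tensors; keep False when samples may already be on CUDA
)
for batch in loader:
    assert batch.device.type == device
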
12 changes: 9 additions & 3 deletions nemo/collections/asr/data/audio_to_diar_label.py
@@ -1065,6 +1065,7 @@ def __init__(
round_digits: int = 2,
soft_targets: bool = False,
subsampling_factor: int = 8,
device: str = 'cpu',
):
super().__init__()
self.collection = EndtoEndDiarizationSpeechLabel(
@@ -1084,6 +1085,7 @@ def __init__(
self.soft_targets = soft_targets
self.round_digits = 2
self.floor_decimal = 10**self.round_digits
self.device = device

def __len__(self):
return len(self.collection)
@@ -1232,11 +1234,13 @@ def __getitem__(self, index):
audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)]

audio_signal_length = torch.tensor(audio_signal.shape[0]).long()
- audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu')
- target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate)
+ audio_signal, audio_signal_length = audio_signal.to(self.device), audio_signal_length.to(self.device)
+ target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate).to(
+     self.device
+ )
targets = self.parse_rttm_for_targets_and_lens(
rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len
- )
+ ).to(self.device)
return audio_signal, audio_signal_length, targets, target_len


@@ -1355,6 +1359,7 @@ def __init__(
window_stride,
global_rank: int,
soft_targets: bool,
device: str,
):
super().__init__(
manifest_filepath=manifest_filepath,
@@ -1365,6 +1370,7 @@ def __init__(
window_stride=window_stride,
global_rank=global_rank,
soft_targets=soft_targets,
device=device,
)

def eesd_train_collate_fn(self, batch):
3 changes: 2 additions & 1 deletion nemo/collections/asr/models/sortformer_diar_models.py
@@ -175,6 +175,7 @@ def __setup_dataloader_from_config(self, config):
window_stride=self._cfg.preprocessor.window_stride,
global_rank=global_rank,
soft_targets=config.soft_targets if 'soft_targets' in config else False,
device=self.device,
)

self.data_collection = dataset.collection
@@ -557,13 +558,13 @@ def test_batch(
audio_signal=audio_signal,
audio_signal_length=audio_signal_length,
)
- self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens)
Collaborator review comment:
It is a bit difficult to understand that the lists (e.g., self.preds_total_list, self.batch_f1_accs_list) are initialized in this function but updated inside self._get_aux_test_batch_evaluations, but that's not a strong opinion~

preds = preds.detach().to('cpu')
if preds.shape[0] == 1: # batch size = 1
self.preds_total_list.append(preds)
else:
self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0]))
torch.cuda.empty_cache()
+ self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens)

logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}")
logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}")
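
On the collaborator comment above, one way to make the flow easier to follow (a sketch only, using a toy metric rather than NeMo's actual F1/precision computation) is to let _get_aux_test_batch_evaluations own both the initialization and the updates of its accumulator lists:

import torch


class _AuxEvalSketch:
    # Sketch only: co-locates initialization and updates of the accumulator lists,
    # which is the readability concern raised in the review comment.
    def _get_aux_test_batch_evaluations(self, batch_idx, preds, targets, target_lens):
        if not hasattr(self, "batch_f1_accs_list"):
            # Lazily create the accumulators on first call instead of in test_batch().
            self.batch_f1_accs_list, self.batch_precision_list = [], []
        # Toy binary metrics; target_lens masking is omitted in this sketch.
        pred_bin, target_bin = (preds > 0.5).float(), (targets > 0.5).float()
        tp = (pred_bin * target_bin).sum()
        precision = tp / pred_bin.sum().clamp(min=1)
        recall = tp / target_bin.sum().clamp(min=1)
        self.batch_precision_list.append(precision)
        self.batch_f1_accs_list.append(2 * precision * recall / (precision + recall).clamp(min=1e-8))

Whether that reads better than explicit initialization in test_batch is the style question the reviewer deliberately leaves open.
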
110 changes: 110 additions & 0 deletions tests/collections/speaker_tasks/test_diar_datasets.py
@@ -0,0 +1,110 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile

import pytest
import torch.cuda

from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.asr.parts.utils.speaker_utils import get_vad_out_from_rttm_line, read_rttm_lines


def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec):
"""
Check if the maximum RTTM duration exceeds the length of the provided audio file.

Args:
rttm_file_path (str): Path to the RTTM file.
wav_len_in_sec (float): Length of the audio file in seconds.

Returns:
bool: True if the maximum RTTM duration is less than or equal to the length of the audio file, False otherwise.
"""
rttm_lines = read_rttm_lines(rttm_file_path)
max_rttm_sec = 0
for line in rttm_lines:
start, dur = get_vad_out_from_rttm_line(line)
max_rttm_sec = max(max_rttm_sec, start + dur)
return max_rttm_sec <= wav_len_in_sec


class TestAudioToSpeechE2ESpkDiarDataset:

@pytest.mark.unit
def test_e2e_speaker_diar_dataset(self, test_data_dir):
manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/diarizer/lsm_val.json'))

batch_size = 4
num_samples = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_dict_list = []
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f:
with open(manifest_path, 'r', encoding='utf-8') as mfile:
for ix, line in enumerate(mfile):
if ix >= num_samples:
break

line = line.replace("tests/data/", test_data_dir + "/").replace("\n", "")
f.write(f"{line}\n")
data_dict = json.loads(line)
data_dict_list.append(data_dict)

f.seek(0)
featurizer = WaveformFeaturizer(sample_rate=16000, int_values=False, augmentor=None)

dataset = AudioToSpeechE2ESpkDiarDataset(
manifest_filepath=f.name,
soft_label_thres=0.5,
session_len_sec=90,
num_spks=4,
featurizer=featurizer,
window_stride=0.01,
global_rank=0,
soft_targets=False,
device=device,
)
dataloader_instance = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
collate_fn=dataset.eesd_train_collate_fn,
drop_last=False,
shuffle=False,
num_workers=1,
pin_memory=False,
)
assert len(dataloader_instance) == (num_samples / batch_size) # Check if the number of batches is correct
batch_counts = len(dataloader_instance)

deviation_thres_rate = 0.01 # 1% deviation allowed
for batch_index, batch in enumerate(dataloader_instance):
if batch_index != batch_counts - 1:
assert len(batch) == batch_size, "Batch size does not match the expected value"
audio_signals, audio_signal_len, targets, target_lens = batch
for sample_index in range(audio_signals.shape[0]):
dataloader_audio_in_sec = audio_signal_len[sample_index].item()
data_dur_in_sec = abs(
data_dict_list[batch_size * batch_index + sample_index]['duration'] * featurizer.sample_rate
- dataloader_audio_in_sec
)
assert (
data_dur_in_sec <= deviation_thres_rate * dataloader_audio_in_sec
), "Duration deviation exceeds 1%"
assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values"
assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values"
assert not torch.isnan(targets).any(), "targets tensor contains NaN values"
assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values"

@@ -16,6 +16,7 @@
import torch
from omegaconf import DictConfig

from nemo.collections.asr.losses import BCELoss
from nemo.collections.asr.models import EncDecDiarLabelModel


@@ -24,7 +25,12 @@ def msdd_model():

preprocessor = {
'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
'params': {"features": 80, "window_size": 0.025, "window_stride": 0.01, "sample_rate": 16000,},
'params': {
"features": 80,
"window_size": 0.025,
"window_stride": 0.01,
"sample_rate": 16000,
},
}

speaker_model_encoder = {
@@ -165,3 +171,37 @@ def test_forward_infer(self, msdd_model):
assert diff <= 1e-6
diff = torch.max(torch.abs(scale_weights_instance - scale_weights_batch))
assert diff <= 1e-6


class TestBCELoss:
@pytest.mark.unit
@pytest.mark.parametrize(
"probs, labels, target_lens, reduction, expected_output",
[
(
torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32),
torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32),
torch.tensor([2]),
"mean",
torch.tensor(0.693147, dtype=torch.float32),
),
(
torch.tensor([[[0.5, 0.5], [0.0, 1.0]]], dtype=torch.float32),
torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32),
torch.tensor([1]),
"mean",
torch.tensor(0.693147, dtype=torch.float32),
),
(
torch.tensor([[[0, 1], [1, 0]]], dtype=torch.float32),
torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32),
torch.tensor([2]),
"mean",
torch.tensor(100, dtype=torch.float32),
),
],
)
def test_loss(self, probs, labels, target_lens, reduction, expected_output):
loss = BCELoss(reduction=reduction)
result = loss(probs=probs, labels=labels, target_lens=target_lens)
assert torch.allclose(result, expected_output), f"Expected {expected_output}, but got {result}"
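
As a quick sanity check on the first two parameterizations (not part of the PR), the expected value 0.693147 is simply ln 2: with uniform 0.5 probabilities, every per-element binary cross-entropy term is -log(0.5) whatever the label. A standalone computation under the textbook BCE definition, ignoring NeMo's masking and clamping details:

import math

import torch

probs = torch.tensor([[[0.5, 0.5], [0.5, 0.5]]], dtype=torch.float32)
labels = torch.tensor([[[1, 0], [0, 1]]], dtype=torch.float32)

# Elementwise BCE: -[y * log(p) + (1 - y) * log(1 - p)]; each term equals -log(0.5) = ln 2.
bce = -(labels * probs.log() + (1 - labels) * (1 - probs).log())
print(bce.mean().item())  # ≈ 0.6931
print(math.log(2))        # 0.6931471805599453

The third case, where the predictions are exactly wrong, presumably exercises whatever clamping the BCELoss implementation applies to log(0); the plain formula above would diverge there.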