From c1c3b44b9cfe5c578cea60836735660e6a73c576 Mon Sep 17 00:00:00 2001 From: FBurkhardt Date: Wed, 18 Dec 2024 11:23:39 +0100 Subject: [PATCH] 0.93.10 --- CHANGELOG.md | 5 +++++ nkululeko/constants.py | 2 +- nkululeko/feat_extract/feats_import.py | 7 ++++++- nkululeko/models/model.py | 1 + nkululeko/reporting/reporter.py | 12 ++++++++++++ 5 files changed, 25 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ffb1dbe..67236b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ Changelog ========= +Version 0.93.10 +-------------- +* added nan check for imported features +* added LOGO result output + Version 0.93.9 -------------- * added manual seed to torch models diff --git a/nkululeko/constants.py b/nkululeko/constants.py index 26e697d..7efc6dc 100644 --- a/nkululeko/constants.py +++ b/nkululeko/constants.py @@ -1,2 +1,2 @@ -VERSION="0.93.9" +VERSION="0.93.10" SAMPLING_RATE = 16000 diff --git a/nkululeko/feat_extract/feats_import.py b/nkululeko/feat_extract/feats_import.py index 44ad37d..708de4b 100644 --- a/nkululeko/feat_extract/feats_import.py +++ b/nkululeko/feat_extract/feats_import.py @@ -30,7 +30,7 @@ def extract(self): "feature type == import needs import_file = ['file1', 'filex']" ) except SyntaxError: - if type(feat_import_files) == str: + if type(feat_import_files) is str: feat_import_files = [feat_import_files] else: self.util.error(f"import_file is wrong: {feat_import_files}") @@ -40,6 +40,11 @@ def extract(self): if not os.path.isfile(feat_import_file): self.util.error(f"no import file: {feat_import_file}") df = audformat.utils.read_csv(feat_import_file) + if df.isnull().values.any(): + self.util.warn( + f"imported features contain {df.isna().sum()} NAN, filling with zero." + ) + df = df.fillna(0) df = self.util.make_segmented_index(df) df = df[df.index.isin(self.data_df.index)] if import_files_append: diff --git a/nkululeko/models/model.py b/nkululeko/models/model.py index 57b93bb..97b2097 100644 --- a/nkululeko/models/model.py +++ b/nkululeko/models/model.py @@ -171,6 +171,7 @@ def _do_logo(self): f"LOGO: {self.logo} folds: mean {results.mean():.3f}, std:" f" {results.std():.3f}" ) + report.print_logo(results) def train(self): """Train the model.""" diff --git a/nkululeko/reporting/reporter.py b/nkululeko/reporting/reporter.py index d81a7bb..c540e3d 100644 --- a/nkululeko/reporting/reporter.py +++ b/nkululeko/reporting/reporter.py @@ -380,6 +380,18 @@ def _plot_confmat(self, truths, preds, plot_name, epoch=None, test_result=None): def set_filename_add(self, my_string): self.filenameadd = f"_{my_string}" + def print_logo(self, results): + res_dir = self.util.get_path("res_dir") + result_str = f"LOGO results: [{','.join(results.astype(str))}]" + file_name = f"{res_dir}/logo_results.txt" + with open(file_name, "w") as text_file: + text_file.write( + f"LOGO: mean {results.mean():.3f}, std: " + f"{results.std():.3f}" + ) + text_file.write("\n") + text_file.write(result_str) + self.util.debug(result_str) + def print_results(self, epoch=None): if epoch is None: epoch = self.epoch