From 0d7ed275d4c4c5a7fb21a527ddf2016700fb36e3 Mon Sep 17 00:00:00 2001 From: Nassim Oufattole Date: Wed, 21 Aug 2024 06:41:41 +0000 Subject: [PATCH] fixed bugs so correlation code filters work now --- .../configs/tabularization/default.yaml | 2 ++ src/MEDS_tabular_automl/tabular_dataset.py | 10 +++------- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/MEDS_tabular_automl/configs/tabularization/default.yaml b/src/MEDS_tabular_automl/configs/tabularization/default.yaml index 6fc3703..ada7dc9 100644 --- a/src/MEDS_tabular_automl/configs/tabularization/default.yaml +++ b/src/MEDS_tabular_automl/configs/tabularization/default.yaml @@ -3,6 +3,8 @@ filtered_code_metadata_fp: ${output_cohort_dir}/metadata/codes.parquet allowed_codes: null min_code_inclusion_count: 10 min_code_inclusion_frequency: null +min_correlation: null +max_by_correlation: null max_included_codes: null window_sizes: - "1d" diff --git a/src/MEDS_tabular_automl/tabular_dataset.py b/src/MEDS_tabular_automl/tabular_dataset.py index 5a6ba43..84a6609 100644 --- a/src/MEDS_tabular_automl/tabular_dataset.py +++ b/src/MEDS_tabular_automl/tabular_dataset.py @@ -177,9 +177,7 @@ def _get_code_set(self) -> tuple[set[int], Mapping[str, list[bool]], int]: hasattr(self.cfg.tabularization, "max_by_correlation") and self.cfg.tabularization.max_by_correlation ): - corrs = self._get_approximate_correlation_per_feature( - self.get_data_shards(0)[0], self.get_data_shards(0)[1] - ) + corrs = self._get_approximate_correlation_per_feature(*self._get_shard_by_index(0)) corrs = np.abs(corrs) sorted_corrs = np.argsort(corrs)[::-1] @@ -187,9 +185,7 @@ def _get_code_set(self) -> tuple[set[int], Mapping[str, list[bool]], int]: set(sorted_corrs[: self.cfg.tabularization.max_by_correlation]) ) if hasattr(self.cfg.tabularization, "min_correlation") and self.cfg.tabularization.min_correlation: - corrs = self._get_approximate_correlation_per_feature( - self.get_data_shards(0)[0], self.get_data_shards(0)[1] - ) + corrs = self._get_approximate_correlation_per_feature(*self._get_shard_by_index(0)) corrs = np.abs(corrs) codes_set = codes_set.intersection( set(np.where(corrs > self.cfg.tabularization.min_correlation)[0]) @@ -356,7 +352,7 @@ def _filter_shard_on_codes_and_freqs(self, agg: str, df: sp.csc_matrix) -> sp.cs Returns: The filtered data frame. """ - if self.codes_set is None: + if not hasattr(self, "codes_set") or self.codes_set is None: return df ckey = f"_filter_shard_on_codes_and_freqs/{agg}"