Skip to content

Commit

Permalink
fixed bugs so correlation code filters work now
Browse files Browse the repository at this point in the history
  • Loading branch information
Oufattole committed Aug 21, 2024
1 parent 527eda5 commit 0d7ed27
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/MEDS_tabular_automl/configs/tabularization/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ filtered_code_metadata_fp: ${output_cohort_dir}/metadata/codes.parquet
allowed_codes: null
min_code_inclusion_count: 10
min_code_inclusion_frequency: null
min_correlation: null
max_by_correlation: null
max_included_codes: null
window_sizes:
- "1d"
Expand Down
10 changes: 3 additions & 7 deletions src/MEDS_tabular_automl/tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,19 +177,15 @@ def _get_code_set(self) -> tuple[set[int], Mapping[str, list[bool]], int]:
hasattr(self.cfg.tabularization, "max_by_correlation")
and self.cfg.tabularization.max_by_correlation
):
corrs = self._get_approximate_correlation_per_feature(
self.get_data_shards(0)[0], self.get_data_shards(0)[1]
)
corrs = self._get_approximate_correlation_per_feature(*self._get_shard_by_index(0))
corrs = np.abs(corrs)
sorted_corrs = np.argsort(corrs)[::-1]

codes_set = codes_set.intersection(
set(sorted_corrs[: self.cfg.tabularization.max_by_correlation])
)
if hasattr(self.cfg.tabularization, "min_correlation") and self.cfg.tabularization.min_correlation:
corrs = self._get_approximate_correlation_per_feature(
self.get_data_shards(0)[0], self.get_data_shards(0)[1]
)
corrs = self._get_approximate_correlation_per_feature(*self._get_shard_by_index(0))
corrs = np.abs(corrs)
codes_set = codes_set.intersection(
set(np.where(corrs > self.cfg.tabularization.min_correlation)[0])
Expand Down Expand Up @@ -356,7 +352,7 @@ def _filter_shard_on_codes_and_freqs(self, agg: str, df: sp.csc_matrix) -> sp.cs
Returns:
The filtered data frame.
"""
if self.codes_set is None:
if not hasattr(self, "codes_set") or self.codes_set is None:
return df

ckey = f"_filter_shard_on_codes_and_freqs/{agg}"
Expand Down

0 comments on commit 0d7ed27

Please sign in to comment.