[wip] filtering features
teyaberg committed Aug 20, 2024
1 parent d6832cb commit 0612730
Showing 8 changed files with 102 additions and 138 deletions.
@@ -30,4 +30,4 @@ model_params:
 
 log_dir: ${model_dir}/.logs/
 
-name: launch_basemodel
+name: launch_sklearnmodel

8 changes: 5 additions & 3 deletions src/MEDS_tabular_automl/configs/tabularization/default.yaml
@@ -1,7 +1,9 @@
 # User inputs
-allowed_codes: null
-min_code_inclusion_frequency: 10
 filtered_code_metadata_fp: ${output_cohort_dir}/tabularized_code_metadata.parquet
+allowed_codes: null
+min_code_inclusion_count: 10
+min_code_inclusion_frequency: 0.01
+max_included_codes: null
 window_sizes:
   - "1d"
   - "7d"
@@ -19,4 +21,4 @@ aggs:
   - "value/max"
 
 # Resolved inputs
-_resolved_codes: ${filter_to_codes:${tabularization.allowed_codes},${tabularization.min_code_inclusion_frequency},${tabularization.filtered_code_metadata_fp}}
+_resolved_codes: ${filter_to_codes:${tabularization.filtered_code_metadata_fp},${tabularization.allowed_codes},${tabularization.min_code_inclusion_count},${tabularization.min_code_inclusion_frequency},${tabularization.max_included_codes}}
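
Aside: the `_resolved_codes` line relies on an OmegaConf resolver named `filter_to_codes` that must be registered before the config is composed. Below is a minimal sketch of what such a resolver could look like, given the argument order above; the function body and the `code`/`count` column names are assumptions for illustration, not the repository's actual implementation.

import polars as pl
from omegaconf import OmegaConf


def filter_to_codes(
    code_metadata_fp: str,
    allowed_codes: list[str] | None,
    min_code_inclusion_count: int | None,
    min_code_inclusion_frequency: float | None,
    max_included_codes: int | None,
) -> list[str]:
    # Hypothetical sketch: "code" and "count" columns are assumed to exist
    # in the tabularized code metadata parquet.
    df = pl.read_parquet(code_metadata_fp)
    if allowed_codes is not None:
        df = df.filter(pl.col("code").is_in(allowed_codes))
    if min_code_inclusion_count is not None:
        df = df.filter(pl.col("count") >= min_code_inclusion_count)
    if min_code_inclusion_frequency is not None:
        # frequency interpreted here as a fraction of all code occurrences
        df = df.filter(pl.col("count") / pl.col("count").sum() >= min_code_inclusion_frequency)
    if max_included_codes is not None:
        df = df.sort("count", descending=True).head(max_included_codes)
    return sorted(df["code"].to_list())


# Registration must happen before Hydra/OmegaConf resolves the config:
OmegaConf.register_new_resolver("filter_to_codes", filter_to_codes)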
@@ -4,13 +4,13 @@
 from loguru import logger
 from omegaconf import DictConfig
 
-from ..base_model import BaseModel
+from ..sklearn_model import SklearnModel
 from ..utils import hydra_loguru_init
 
-# config_yaml = files("MEDS_tabular_automl").joinpath("configs/launch_basemodel.yaml")
+# config_yaml = files("MEDS_tabular_automl").joinpath("configs/launch_sklearnmodel.yaml")
 # if not config_yaml.is_file():
 #     raise FileNotFoundError("Core configuration not successfully installed!")
-config_yaml = Path("./src/MEDS_tabular_automl/configs/launch_basemodel.yaml")
+config_yaml = Path("./src/MEDS_tabular_automl/configs/launch_sklearnmodel.yaml")
 
 
 @hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
@@ -28,7 +28,7 @@ def main(cfg: DictConfig) -> float:
     if not cfg.loguru_init:
         hydra_loguru_init()
     try:
-        model = BaseModel(cfg)
+        model = SklearnModel(cfg)
         model.train()
         auc = model.evaluate()
         logger.info(f"AUC: {auc}")
4 changes: 3 additions & 1 deletion src/MEDS_tabular_automl/scripts/tabularize_static.py
@@ -92,9 +92,11 @@ def read_fn(_):
 
     def compute_fn(_):
        filtered_feature_columns = filter_to_codes(
+            cfg.input_code_metadata_fp,
             cfg.tabularization.allowed_codes,
+            cfg.tabularization.min_code_inclusion_count,
             cfg.tabularization.min_code_inclusion_frequency,
-            cfg.input_code_metadata_fp,
+            cfg.tabularization.max_included_codes,
         )
         feature_freqs = get_feature_freqs(cfg.input_code_metadata_fp)
         filtered_feature_columns_set = set(filtered_feature_columns)
@@ -11,8 +11,8 @@
 from .tabular_dataset import TabularDataset
 
 
-class BaseIterator(TabularDataset, TimeableMixin):
-    """BaseIterator class for loading and processing data shards for use in SciKit-Learn models.
+class SklearnIterator(TabularDataset, TimeableMixin):
+    """SklearnIterator class for loading and processing data shards for use in SciKit-Learn models.
 
     This class provides functionality for iterating through data shards, loading
     feature data and labels, and processing them based on the provided configuration.
@@ -37,7 +37,7 @@ class BaseIterator(TabularDataset, TimeableMixin):
     """
 
     def __init__(self, cfg: DictConfig, split: str):
-        """Initializes the BaseIterator with the provided configuration and data split.
+        """Initializes the SklearnIterator with the provided configuration and data split.
 
         Args:
             cfg: The configuration dictionary.
@@ -57,11 +57,11 @@ def __init__(self, cfg: DictConfig, split: str):
         # function(data, labels)
 
 
-class BaseMatrix(TimeableMixin):
-    """BaseMatrix class for loading and processing data shards for use in SciKit-Learn models."""
+class SklearnMatrix(TimeableMixin):
+    """SklearnMatrix class for loading and processing data shards for use in SciKit-Learn models."""
 
     def __init__(self, data: sp.csr_matrix, labels: np.ndarray):
-        """Initializes the BaseMatrix with the provided configuration and data split.
+        """Initializes the SklearnMatrix with the provided configuration and data split.
 
         Args:
             data
@@ -77,7 +77,7 @@ def get_label(self):
         return self.labels
 
 
-class BaseModel(TimeableMixin):
+class SklearnModel(TimeableMixin):
     """Class for configuring, training, and evaluating a SciKit-Learn model.
 
     This class utilizes the configuration settings provided to manage the training and evaluation
@@ -178,40 +178,54 @@ def train(self):
     @TimeableMixin.TimeAs
     def _build_matrix_in_memory(self):
         """Builds the DMatrix from the data in memory."""
-        self.dtrain = BaseMatrix(*self.itrain.get_data())
-        self.dtuning = BaseMatrix(*self.ituning.get_data())
-        self.dheld_out = BaseMatrix(*self.iheld_out.get_data())
+        self.dtrain = SklearnMatrix(*self.itrain.get_data())
+        self.dtuning = SklearnMatrix(*self.ituning.get_data())
+        self.dheld_out = SklearnMatrix(*self.iheld_out.get_data())
 
     @TimeableMixin.TimeAs
     def _build_iterators(self):
         """Builds the iterators for training, validation, and testing."""
-        self.itrain = BaseIterator(self.cfg, split="train")
-        self.ituning = BaseIterator(self.cfg, split="tuning")
-        self.iheld_out = BaseIterator(self.cfg, split="held_out")
+        self.itrain = SklearnIterator(self.cfg, split="train")
+        self.ituning = SklearnIterator(self.cfg, split="tuning")
+        self.iheld_out = SklearnIterator(self.cfg, split="held_out")
 
     @TimeableMixin.TimeAs
-    def evaluate(self) -> float:
+    def evaluate(self, split: str = "tuning") -> float:
         """Evaluates the model on the tuning set.
 
         Returns:
             The evaluation metric as the ROC AUC score.
         """
+        # depending on the split, point to the correct data
+        if split == "tuning":
+            dsplit = self.dtuning
+            isplit = self.ituning
+        elif split == "held_out":
+            dsplit = self.dheld_out
+            isplit = self.iheld_out
+        elif split == "train":
+            dsplit = self.dtrain
+            isplit = self.itrain
+        else:
+            raise ValueError(f"Split {split} is not valid.")
+
         # check if model has predict_proba method
         if not hasattr(self.model, "predict_proba"):
             raise ValueError(f"Model {self.model.__class__.__name__} does not have a predict_proba method.")
         # two cases: data is in memory or data is streamed
         if self.keep_data_in_memory:
-            y_pred = self.model.predict_proba(self.dtuning.get_data())[:, 1]
-            y_true = self.dtuning.get_label()
+            y_pred = self.model.predict_proba(dsplit.get_data())[:, 1]
+            y_true = dsplit.get_label()
         else:
             y_pred = []
             y_true = []
-            for shard_idx in range(len(self.ituning._data_shards)):
-                data, labels = self.ituning.get_data_shards(shard_idx)
+            for shard_idx in range(len(isplit._data_shards)):
+                data, labels = isplit.get_data_shards(shard_idx)
                 y_pred.extend(self.model.predict_proba(data)[:, 1])
                 y_true.extend(labels)
         y_pred = np.array(y_pred)
         y_true = np.array(y_true)
 
         # check if y_pred and y_true are not empty
         if len(y_pred) == 0 or len(y_true) == 0:
             raise ValueError("Predictions or true labels are empty.")
140 changes: 36 additions & 104 deletions src/MEDS_tabular_automl/tabular_dataset.py
@@ -54,7 +54,6 @@ def __init__(self, cfg: DictConfig, split: str = "train"):
             [shard.stem for shard in list_subdir_files(Path(cfg.input_label_dir) / split, "parquet")]
         )
         self.valid_event_ids, self.labels = None, None
-        # self.valid_event_ids, self.labels = self._load_ids_and_labels()
 
         self.codes_set, self.code_masks, self.num_features = self._get_code_set()
 
@@ -75,7 +74,7 @@ def _get_code_masks(self, feature_columns: list, codes_set: set) -> Mapping[str,
         code_masks = {}
         for agg in set(self.cfg.tabularization.aggs):
             feature_ids = get_feature_indices(agg, feature_columns)
-            code_mask = [True if idx in codes_set else False for idx in feature_ids]
+            code_mask = [idx in codes_set for idx in feature_ids]
             code_masks[agg] = code_mask
         return code_masks
 
@@ -121,7 +120,6 @@ def _load_ids_and_labels(
             if load_ids:
                 cached_event_ids[shard] = label_df.select(pl.col("event_id")).collect().to_series()
 
-            # TODO: check this for Nan or any other case we need to worry about
             if load_labels:
                 cached_labels[shard] = label_df.select(pl.col("label")).collect().to_series()
                 if self.cfg.model_params.iterator.binarize_task:
@@ -171,11 +169,46 @@ def _get_code_set(self) -> tuple[set[int], Mapping[str, list[bool]], int]:
         allowed_codes = set(self.cfg.tabularization._resolved_codes)
         codes_set = {feature_dict[code] for code in feature_dict if code in allowed_codes}
 
+        if hasattr(self.cfg.tabularization, "max_by_correlation"):
+            corrs = self._get_approximate_correlation_per_feature(self.get_data_shards(0)[0], self.get_data_shards(0)[1])
+            corrs = np.abs(corrs)
+            sorted_corrs = np.argsort(corrs)[::-1]
+            codes_set = set(sorted_corrs[: self.cfg.tabularization.max_by_correlation])
+        if hasattr(self.cfg.tabularization, "min_correlation"):
+            corrs = self._get_approximate_correlation_per_feature(self.get_data_shards(0)[0], self.get_data_shards(0)[1])
+            corrs = np.abs(corrs)
+            codes_set = set(np.where(corrs > self.cfg.tabularization.min_correlation)[0])
+
         return (
             codes_set,
             self._get_code_masks(feature_columns, codes_set),
             len(feature_columns),
         )
 
+    def _get_approximate_correlation_per_feature(self, X: sp.csc_matrix, y: np.ndarray) -> np.ndarray:
+        """Calculates the approximate correlation of each feature with the target.
+
+        Args:
+            X: The feature data.
+            y: The target labels.
+
+        Returns:
+            The approximate correlation of each feature with the target.
+        """
+        # calculate the pearson r correlation of each feature with the target;
+        # this is a very rough approximation and should be used for feature selection,
+        # not as a definitive measure of feature importance
+
+        # check that y has information
+        if len(np.unique(y)) == 1:
+            raise ValueError("Labels have no information. Cannot calculate correlation.")
+
+        from scipy.stats import pearsonr
+
+        corrs = np.zeros(X.shape[1])
+        for i in range(X.shape[1]):
+            corrs[i] = pearsonr(X[:, i].toarray().flatten(), y)[0]
+        return corrs
+
 
     @TimeableMixin.TimeAs
     def _load_dynamic_shard_from_file(self, path: Path, idx: int) -> sp.csc_matrix:
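
Aside: calling `scipy.stats.pearsonr` once per column densifies each column and can be slow for wide matrices. A vectorized alternative sketch, not part of this commit, computes the same Pearson r for every column directly with sparse operations (assuming `X` is CSC and `y` is non-constant, as the check above already enforces):

import numpy as np
import scipy.sparse as sp


def pearson_corr_sparse(X: sp.csc_matrix, y: np.ndarray) -> np.ndarray:
    """Pearson r of every column of X against y, without densifying X."""
    n = X.shape[0]
    y_c = y.astype(np.float64) - y.mean()
    # cov(x_j, y) = E[x_j * (y - E[y])] = (X^T y_c) / n
    cov = np.asarray(X.T @ y_c).ravel() / n
    mean_x = np.asarray(X.mean(axis=0)).ravel()
    ex2 = np.asarray(X.multiply(X).mean(axis=0)).ravel()
    std_x = np.sqrt(np.maximum(ex2 - mean_x**2, 1e-12))  # guard zero-variance columns
    std_y = y_c.std()
    return cov / (std_x * std_y)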
@@ -301,99 +334,6 @@ def get_data(self) -> tuple[sp.csc_matrix, np.ndarray]:
         """
         return self.get_data_shards(range(len(self._data_shards)))
 
-    def set_event_ids(self, event_ids=None | list[int]):
-        """Sets the valid event IDs for each shard.
-
-        Args:
-            event_ids: List of event IDs for each shard.
-        """
-        if event_ids is None:
-            self.valid_event_ids = self._load_event_ids()
-        else:
-            # parse some list of events they care about
-            pass
-
-    def set_labels(self, labels=None | list[int]):
-        """Sets the labels for each shard.
-
-        Args:
-            labels: List of labels for each shard.
-        """
-        if labels is None:
-            self.labels = self._load_labels()
-        else:
-            # parse some list of events they care about
-            pass
-
-    def set_codes(self, codes: list[str]):
-        """Sets the codes to the passed code set. Redeclares the code masks to match.
-
-        Args:
-            codes: List of codes to include.
-        """
-        self.codes_set = set(codes)
-        self.code_masks = self._get_code_masks(self.code_masks.keys(), self.codes_set)
-
-    def add_code(self, code: str):
-        """Adds a code to the set of codes to include in the data.
-
-        Args:
-            code: The code to add to the set.
-        """
-        if code not in self.codes_set:
-            self.codes_set.add(code)
-            self.code_masks = self._get_code_masks(self.code_masks.keys(), self.codes_set)
-
-    def remove_code(self, code: str):
-        """Removes a code from the set of codes to include in the data.
-
-        Args:
-            code: The code to remove from the set.
-        """
-        if code in self.codes_set:
-            self.codes_set.remove(code)
-            self.code_masks = self._get_code_masks(self.code_masks.keys(), self.codes_set)
-
-    def get_codes(self) -> set[str]:
-        """Retrieves the set of codes to include in the data.
-
-        Returns:
-            The set of codes to include.
-        """
-        return self.codes_set
-
-    def get_num_features(self) -> int:
-        """Retrieves the total number of features in the data.
-
-        Returns:
-            The total number of features.
-        """
-        return self.num_features
-
-    def get_valid_event_ids(self) -> Mapping[int, list]:
-        """Retrieves the valid event IDs for each shard.
-
-        Returns:
-            A mapping from shard indices to lists of valid event IDs.
-        """
-        return self.valid_event_ids
-
-    def get_label(self) -> Mapping[int, list]:
-        """Retrieves the labels for each shard.
-
-        Returns:
-            A mapping from shard indices to lists of labels.
-        """
-        return self.labels
-
-    def get_data_shard_list(self) -> list[str]:
-        """Retrieves the list of data shards.
-
-        Returns:
-            The list of data shards.
-        """
-        return self._data_shards
-
     def get_data_shard_count(self) -> int:
         """Retrieves the number of data shards.
@@ -402,14 +342,6 @@ def get_data_shard_count(self) -> int:
         """
         return len(self._data_shards)
 
-    def get_split(self) -> str:
-        """Retrieves the data split being used.
-
-        Returns:
-            The data split being used.
-        """
-        return self.split
-
     def get_classes(self) -> int:
         """Retrieves the unique labels in the data.
