Skip to content

Commit

Permalink
finished utils docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
aleksiakolo committed Jun 12, 2024
1 parent 846a604 commit d7d72b1
Showing 1 changed file with 50 additions and 23 deletions.
73 changes: 50 additions & 23 deletions src/MEDS_tabular_automl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,16 @@ def load_meds_data(MEDS_cohort_dir: str, load_data: bool = True) -> Mapping[str,
return split_to_df


def get_events_df(shard_df: pl.DataFrame, feature_columns) -> pl.DataFrame:
"""Extracts Events DataFrame with one row per observation (timestamps can be duplicated)"""
def get_events_df(shard_df: pl.LazyFrame, feature_columns) -> pl.LazyFrame:
"""Extracts and filters an Events LazyFrame with one row per observation (timestamps can be duplicated).
Args:
shard_df: The LazyFrame shard from which to extract events.
feature_columns: The columns that define features used to filter the LazyFrame.
Returns:
A LazyFrame where each row corresponds to an event, filtered by feature columns.
"""
# Filter out feature_columns that were not present in the training set
raw_feature_columns = ["/".join(c.split("/")[:-1]) for c in feature_columns]
shard_df = shard_df.filter(pl.col("code").is_in(raw_feature_columns))
Expand All @@ -427,8 +435,15 @@ def get_events_df(shard_df: pl.DataFrame, feature_columns) -> pl.DataFrame:
return ts_shard_df


def get_unique_time_events_df(events_df: pl.DataFrame):
"""Updates Events DataFrame to have unique timestamps and sorted by patient_id and timestamp."""
def get_unique_time_events_df(events_df: pl.LazyFrame) -> pl.LazyFrame:
"""Ensures all timestamps in the events LazyFrame are unique and sorted by patient_id and timestamp.
Args:
events_df: Events LazyFrame to process.
Returns:
A LazyFrame with unique timestamps, sorted by patient_id and timestamp.
"""
assert events_df.select(pl.col("timestamp")).null_count().collect().item() == 0
# Check events_df is sorted - so it aligns with the ts_matrix we generate later in the pipeline
events_df = (
Expand All @@ -440,8 +455,19 @@ def get_unique_time_events_df(events_df: pl.DataFrame):
return events_df


def get_feature_names(agg, feature_columns) -> str:
"""Indices of columns in feature_columns list."""
def get_feature_names(agg: str, feature_columns: list[str]) -> str:
"""Extracts feature column names based on aggregation type from a list of column names.
Args:
agg: The aggregation type to filter by.
feature_columns: The list of feature column names.
Returns:
The filtered list of feature column names based on the aggregation type.
Raises:
ValueError: If the aggregation type is unknown or unsupported.
"""
if agg in [STATIC_CODE_AGGREGATION, STATIC_VALUE_AGGREGATION]:
return [c for c in feature_columns if c.endswith(agg)]
elif agg in CODE_AGGREGATIONS:
Expand All @@ -452,31 +478,32 @@ def get_feature_names(agg, feature_columns) -> str:
raise ValueError(f"Unknown aggregation type {agg}")


def get_feature_indices(agg, feature_columns) -> str:
"""Indices of columns in feature_columns list."""
def get_feature_indices(agg: str, feature_columns: list[str]) -> list[int]:
"""Generates a list of feature name indices based on the aggregation type.
Args:
agg: The aggregation type used to filter feature names.
feature_columns: The list of all feature column names.
Returns:
Indices of the columns that match the aggregation type.
"""
feature_to_index = {c: i for i, c in enumerate(feature_columns)}
agg_features = get_feature_names(agg, feature_columns)
return [feature_to_index[c] for c in agg_features]


def store_config_yaml(config_fp: Path, cfg: DictConfig):
"""Stores configuration parameters into a JSON file.
def store_config_yaml(config_fp: Path, cfg: DictConfig) -> None:
"""Stores configuration parameters into a YAML file.
This function writes a dictionary of parameters, which includes patient partitioning
information and configuration details, to a specified JSON file.
information and configuration details, to a specified YAML file.
Args:
config_fp: The file path for the JSON file where config should be stored.
config_fp: The file path for the YAML file where config should be stored.
cfg: A configuration object containing settings like the number of patients
per sub-shard, minimum code inclusion frequency, and flags for updating or overwriting existing files.
Behavior:
- If config_fp exists and cfg.do_overwrite is False (without do_update being True), a
FileExistsError is raised to prevent unintentional data loss.
Raises:
- ValueError: If there are discrepancies between old and new parameters during an update.
- FileExistsError: If the file exists and overwriting is not allowed.
per sub-shard, minimum code inclusion frequency, and flags for updating
or overwriting existing files.
"""
OmegaConf.save(cfg, config_fp)

Expand All @@ -485,8 +512,8 @@ def get_shard_prefix(base_path: Path, fp: Path) -> str:
"""Extracts the shard prefix from a file path by removing the base path.
Args:
base_path: The base path to remove.
fp: The file path to extract the shard prefix from.
base_path: The base path to remove from the file path.
fp: The full file path from which to extract the shard prefix.
Returns:
The shard prefix (the file path relative to the base path with the suffix removed).
Expand Down

0 comments on commit d7d72b1

Please sign in to comment.