Commit
added docstrings and a simple test case checking the number of subjects is correct when producing static representations
Showing 5 changed files with 667 additions and 542 deletions.
@@ -0,0 +1,112 @@
"""This module provides functions for generating static representations of patient data for use in automated | ||
machine learning models. It includes functionality to summarize measurements based on static features and then | ||
transform them into a tabular format suitable for analysis. The module leverages the polars library for | ||
efficient data manipulation. | ||
Functions: | ||
- _summarize_static_measurements: Summarizes static measurements from a given DataFrame. | ||
- get_flat_static_rep: Produces a tabular representation of static data features. | ||
""" | ||

import polars as pl

from MEDS_tabular_automl.utils import (
    DF_T,
    _normalize_flat_rep_df_cols,
    _parse_flat_feature_column,
)


def _summarize_static_measurements(
    feature_columns: list[str],
    df: DF_T,
) -> pl.LazyFrame:
    """Aggregates static measurements for feature columns that are marked as 'present' or 'first'.

    Parameters:
    - feature_columns (list[str]): List of feature column identifiers that are specifically marked
        for static analysis.
    - df (DF_T): Data frame from which features will be extracted and summarized.

    Returns:
    - pl.LazyFrame: A LazyFrame containing the summarized data pivoted by 'patient_id'
        for each static feature.

    This function first filters for features that need to be recorded as the first occurrence
    or simply as present, then performs a pivot to reshape the data for each patient, providing
    a tabular format where each row represents a patient and each column represents a static feature.
    """
    static_present = [c for c in feature_columns if c.startswith("STATIC_") and c.endswith("present")]
    static_first = [c for c in feature_columns if c.startswith("STATIC_") and c.endswith("first")]

    # Handling 'first' static values
    static_first_codes = [_parse_flat_feature_column(c)[1] for c in static_first]
    code_subset = df.filter(pl.col("code").is_in(static_first_codes))
    first_code_subset = code_subset.groupby(pl.col("patient_id")).first().collect()
    static_value_pivot_df = first_code_subset.pivot(
        index=["patient_id"], columns=["code"], values=["numerical_value"], aggregate_function=None
    )
    # rename code to feature name
    remap_cols = {
        input_name: output_name
        for input_name, output_name in zip(static_first_codes, static_first)
        if input_name in static_value_pivot_df.columns
    }
    static_value_pivot_df = static_value_pivot_df.select(
        *["patient_id"], *[pl.col(k).alias(v).cast(pl.Boolean) for k, v in remap_cols.items()]
    )
    # pivot can be faster: https://stackoverflow.com/questions/73522017/replacing-a-pivot-with-a-lazy-groupby-operation  # noqa: E501
    # TODO: consider casting with .cast(pl.Float32)

    # Handling 'present' static indicators
    static_present_codes = [_parse_flat_feature_column(c)[1] for c in static_present]
    static_present_pivot_df = (
        df.select(*["patient_id", "code"])
        .filter(pl.col("code").is_in(static_present_codes))
        .with_columns(pl.lit(True).alias("__indicator"))
        .collect()
        .pivot(
            index=["patient_id"],
            columns=["code"],
            values="__indicator",
            aggregate_function=None,
        )
    )
    remap_cols = {
        input_name: output_name
        for input_name, output_name in zip(static_present_codes, static_present)
        if input_name in static_present_pivot_df.columns
    }
    # rename columns to final feature names
    static_present_pivot_df = static_present_pivot_df.select(
        *["patient_id"], *[pl.col(k).alias(v).cast(pl.Boolean) for k, v in remap_cols.items()]
    )
    return pl.concat([static_value_pivot_df, static_present_pivot_df], how="align")


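# Note: pl.concat(..., how="align") (used above) outer-joins the frames on
# their shared key column ("patient_id" here), filling missing cells with
# nulls. A tiny demo on hypothetical data:
#
#     a = pl.DataFrame({"patient_id": [1, 2], "x": [True, True]})
#     b = pl.DataFrame({"patient_id": [2, 3], "y": [True, True]})
#     pl.concat([a, b], how="align")
#     # -> rows for patient_id 1, 2, 3; x and y are null where absent
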
def get_flat_static_rep(
    feature_columns: list[str],
    shard_df: DF_T,
) -> pl.LazyFrame:
    """Produces a raw representation for static data from a specified shard DataFrame.

    Parameters:
    - feature_columns (list[str]): List of feature columns to include in the static representation.
    - shard_df (DF_T): The shard DataFrame containing patient data.

    Returns:
    - pl.LazyFrame: A LazyFrame that includes all static features for the data provided.

    This function selects the appropriate static features, summarizes them using
    _summarize_static_measurements, and then normalizes the resulting data to ensure it is
    suitable for further analysis or machine learning tasks.
    """
    static_features = [c for c in feature_columns if c.startswith("STATIC_")]
    static_measurements = _summarize_static_measurements(static_features, df=shard_df)
    # fill up missing feature columns with nulls
    normalized_measurements = _normalize_flat_rep_df_cols(
        static_measurements,
        static_features,
        set_count_0_to_null=False,
    )
    return normalized_measurements
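The test file referenced in the commit message is not shown in this view. A minimal sketch of the invariant such a test would check (one row per subject in the static representation), using the same 'present'-indicator pivot technique as _summarize_static_measurements on hypothetical data rather than the real helpers:

    import polars as pl

    df = pl.DataFrame(
        {
            "patient_id": [1, 1, 2, 3],
            "code": ["EYE_COLOR//BLUE", "HEIGHT", "EYE_COLOR//BROWN", "HEIGHT"],
        }
    )
    pivoted = df.with_columns(pl.lit(True).alias("__indicator")).pivot(
        index=["patient_id"], columns=["code"], values="__indicator", aggregate_function=None
    )
    assert pivoted.height == df["patient_id"].n_unique()  # one row per subject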
@@ -0,0 +1,230 @@
"""WIP. | ||
This file will be used to generate time series features from the raw data. | ||
""" | ||
from collections.abc import Callable | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import polars as pl | ||
import polars.selectors as cs | ||
|
||
from MEDS_tabular_automl.utils import ( | ||
DF_T, | ||
_normalize_flat_rep_df_cols, | ||
_parse_flat_feature_column, | ||
) | ||
|
||
|
||
def _summarize_dynamic_measurements(
    self,
    feature_columns: list[str],
    include_only_subjects: set[int] | None = None,
) -> pl.LazyFrame:
    # NOTE: WIP -- this function still takes `self` (it references
    # self.dynamic_measurements_df, self.events_df, etc.), so it is not yet
    # callable as the free function that get_flat_ts_rep below expects.
    if include_only_subjects is None:
        df = self.dynamic_measurements_df
    else:
        # Restrict to events belonging to the requested subjects.
        df = self.dynamic_measurements_df.join(
            self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select("event_id"),
            on="event_id",
            how="inner",
        )
|
||
valid_measures = {} | ||
for feat_col in feature_columns: | ||
temp, meas, feat = _parse_flat_feature_column(feat_col) | ||
|
||
if temp != "dynamic": | ||
continue | ||
|
||
if meas not in valid_measures: | ||
valid_measures[meas] = set() | ||
valid_measures[meas].add(feat) | ||
|
||
out_dfs = {} | ||
for m, allowed_vocab in valid_measures.items(): | ||
cfg = self.measurement_configs[m] | ||
|
||
total_observations = int( | ||
np.ceil( | ||
cfg.observation_rate_per_case | ||
* cfg.observation_rate_over_cases | ||
* sum(self.n_events_per_subject.values()) | ||
) | ||
) | ||
|
||
count_type = self.get_smallest_valid_uint_type(total_observations) | ||
|
||
if cfg.modality == "univariate_regression" and cfg.vocabulary is None: | ||
prefix = f"dynamic/{m}/{m}" | ||
|
||
key_col = pl.col(m) | ||
val_col = pl.col(m).drop_nans().cast(pl.Float32) | ||
|
||
out_dfs[m] = ( | ||
df.lazy() | ||
.select("measurement_id", "event_id", m) | ||
.filter(pl.col(m).is_not_null()) | ||
.groupby("event_id") | ||
.agg( | ||
pl.col(m).is_not_null().sum().cast(count_type).alias(f"{prefix}/count"), | ||
( | ||
(pl.col(m).is_not_nan() & pl.col(m).is_not_null()) | ||
.sum() | ||
.cast(count_type) | ||
.alias(f"{prefix}/has_values_count") | ||
), | ||
val_col.sum().alias(f"{prefix}/sum"), | ||
(val_col**2).sum().alias(f"{prefix}/sum_sqd"), | ||
val_col.min().alias(f"{prefix}/min"), | ||
val_col.max().alias(f"{prefix}/max"), | ||
) | ||
) | ||
continue | ||
elif cfg.modality == "multivariate_regression": | ||
column_cols = [m, m] | ||
values_cols = [m, cfg.values_column] | ||
key_prefix = f"{m}_{m}_" | ||
val_prefix = f"{cfg.values_column}_{m}_" | ||
|
||
key_col = cs.starts_with(key_prefix) | ||
val_col = cs.starts_with(val_prefix).drop_nans().cast(pl.Float32) | ||
|
||
aggs = [ | ||
key_col.is_not_null() | ||
.sum() | ||
.cast(count_type) | ||
.map_alias(lambda c: f"dynamic/{m}/{c.replace(key_prefix, '')}/count"), | ||
( | ||
(cs.starts_with(val_prefix).is_not_null() & cs.starts_with(val_prefix).is_not_nan()) | ||
.sum() | ||
.map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/has_values_count") | ||
), | ||
val_col.sum().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/sum"), | ||
(val_col**2).sum().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/sum_sqd"), | ||
val_col.min().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/min"), | ||
val_col.max().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/max"), | ||
] | ||
else: | ||
column_cols = [m] | ||
values_cols = [m] | ||
aggs = [ | ||
pl.all().is_not_null().sum().cast(count_type).map_alias(lambda c: f"dynamic/{m}/{c}/count") | ||
] | ||
|
||
ID_cols = ["measurement_id", "event_id"] | ||
out_dfs[m] = ( | ||
df.select(*ID_cols, *set(column_cols + values_cols)) | ||
.filter(pl.col(m).is_in(allowed_vocab)) | ||
.pivot( | ||
index=ID_cols, | ||
columns=column_cols, | ||
values=values_cols, | ||
aggregate_function=None, | ||
) | ||
.lazy() | ||
.drop("measurement_id") | ||
.groupby("event_id") | ||
.agg(*aggs) | ||
) | ||
|
||
return pl.concat(list(out_dfs.values()), how="align") | ||
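
# NOTE: self.get_smallest_valid_uint_type (used above) is not shown in this
# diff. A plausible sketch, assuming it returns the narrowest unsigned polars
# dtype that can hold `n` observations:
#
#     def get_smallest_valid_uint_type(n: int) -> pl.DataType:
#         for dtype, bits in ((pl.UInt8, 8), (pl.UInt16, 16), (pl.UInt32, 32)):
#             if n < 2**bits:
#                 return dtype
#         return pl.UInt64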


def _summarize_over_window(df: DF_T, window_size: str) -> pl.LazyFrame:
    """Apply aggregations to the raw representation over a window size."""
    if isinstance(df, Path):
        df = pl.scan_parquet(df)

    def time_aggd_col_alias_fntr(new_agg: str | None = None) -> Callable[[str], str]:
        if new_agg is None:

            def f(c: str) -> str:
                return "/".join([window_size] + c.split("/")[1:])

        else:

            def f(c: str) -> str:
                return "/".join([window_size] + c.split("/")[1:-1] + [new_agg])

        return f

    # Columns to convert to counts:
    present_indicator_cols = cs.ends_with("/present")

    # Columns to convert to value aggregations:
    value_cols = cs.ends_with("/value")

    # Columns to aggregate via other operations
    cnt_cols = (cs.ends_with("/count") | cs.ends_with("/has_values_count")).fill_null(0)

    cols_to_sum = cs.ends_with("/sum") | cs.ends_with("/sum_sqd")
    cols_to_min = cs.ends_with("/min")
    cols_to_max = cs.ends_with("/max")

    if window_size == "FULL":
        df = df.groupby("subject_id").agg(
            "timestamp",
            # present to counts
            present_indicator_cols.cumsum().map_alias(time_aggd_col_alias_fntr("count")),
            # values to stats
            value_cols.is_not_null().cumsum().map_alias(time_aggd_col_alias_fntr("count")),
            (
                (value_cols.is_not_null() & value_cols.is_not_nan())
                .cumsum()
                .map_alias(time_aggd_col_alias_fntr("has_values_count"))
            ),
            value_cols.cumsum().map_alias(time_aggd_col_alias_fntr("sum")),
            (value_cols**2).cumsum().map_alias(time_aggd_col_alias_fntr("sum_sqd")),
            value_cols.cummin().map_alias(time_aggd_col_alias_fntr("min")),
            value_cols.cummax().map_alias(time_aggd_col_alias_fntr("max")),
            # Raw aggregations
            cnt_cols.cumsum().map_alias(time_aggd_col_alias_fntr()),
            cols_to_sum.cumsum().map_alias(time_aggd_col_alias_fntr()),
            cols_to_min.cummin().map_alias(time_aggd_col_alias_fntr()),
            cols_to_max.cummax().map_alias(time_aggd_col_alias_fntr()),
        )
        df = df.explode(*[c for c in df.columns if c != "subject_id"])
    else:
        df = df.groupby_rolling(
            index_column="timestamp",
            by="subject_id",
            period=window_size,
        ).agg(
            # present to counts
            present_indicator_cols.sum().map_alias(time_aggd_col_alias_fntr("count")),
            # values to stats
            value_cols.is_not_null().sum().map_alias(time_aggd_col_alias_fntr("count")),
            (
                (value_cols.is_not_null() & value_cols.is_not_nan())
                .sum()
                .map_alias(time_aggd_col_alias_fntr("has_values_count"))
            ),
            value_cols.sum().map_alias(time_aggd_col_alias_fntr("sum")),
            (value_cols**2).sum().map_alias(time_aggd_col_alias_fntr("sum_sqd")),
            value_cols.min().map_alias(time_aggd_col_alias_fntr("min")),
            value_cols.max().map_alias(time_aggd_col_alias_fntr("max")),
            # Raw aggregations
            cnt_cols.sum().map_alias(time_aggd_col_alias_fntr()),
            cols_to_sum.sum().map_alias(time_aggd_col_alias_fntr()),
            cols_to_min.min().map_alias(time_aggd_col_alias_fntr()),
            cols_to_max.max().map_alias(time_aggd_col_alias_fntr()),
        )

    return _normalize_flat_rep_df_cols(df, set_count_0_to_null=True)
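
# A toy illustration of the groupby_rolling pattern above, on hypothetical
# data (older polars API, matching the rest of this file):
#
#     from datetime import datetime
#     events = pl.DataFrame({
#         "subject_id": [1, 1, 1],
#         "timestamp": [datetime(2020, 1, 1), datetime(2020, 1, 5), datetime(2020, 1, 20)],
#         "lab/value": [1.0, 2.0, 4.0],
#     }).sort("subject_id", "timestamp")
#     events.groupby_rolling(index_column="timestamp", by="subject_id", period="7d").agg(
#         pl.col("lab/value").sum().alias("7d/lab/value/sum")
#     )
#     # each output row sums the 7 days ending at that row's timestamp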


def get_flat_ts_rep(
    feature_columns: list[str],
    **kwargs,
) -> pl.LazyFrame:
    """Produce raw representation for dynamic data."""
    # NOTE: WIP -- _summarize_dynamic_measurements still expects `self` as its
    # first argument, so this call does not yet line up with its signature.
    return _normalize_flat_rep_df_cols(
        _summarize_dynamic_measurements(feature_columns, **kwargs)
        .sort(by=["subject_id", "timestamp"])
        .collect()
        .lazy(),
        [c for c in feature_columns if c.startswith("dynamic")],
    )
    # The above .collect().lazy() shouldn't be necessary, but it appears to be for some reason...
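# A tiny illustration of that materialize-then-relazify pattern, on
# hypothetical data (not part of this module's API):
#
#     lf = pl.LazyFrame({"subject_id": [2, 1], "timestamp": [2, 1]})
#     lf = lf.sort(by=["subject_id", "timestamp"]).collect().lazy()
#     lf.collect()  # computed from the cached, already-sorted frame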