
Commit

added docstrings and a simple test case checking the number of subjects is correct when producing static representations
Oufattole committed May 26, 2024
1 parent 4a486aa commit 19f0f4e
Showing 5 changed files with 667 additions and 542 deletions.
3 changes: 0 additions & 3 deletions configs/tabularize.yaml
@@ -20,9 +20,6 @@ aggs:
- "value/last"
- "value/slope"
- "value/intercept"
- "value/residual/sum"
- "value/residual/sum_sqd"
numeric_value_impute_strategy: "drop"
dynamic_threshold: 0.01
numerical_value_threshold: 0.1

112 changes: 112 additions & 0 deletions src/MEDS_tabular_automl/generate_static_features.py
@@ -0,0 +1,112 @@
"""This module provides functions for generating static representations of patient data for use in automated
machine learning models. It includes functionality to summarize measurements based on static features and then
transform them into a tabular format suitable for analysis. The module leverages the polars library for
efficient data manipulation.

Functions:
- _summarize_static_measurements: Summarizes static measurements from a given DataFrame.
- get_flat_static_rep: Produces a tabular representation of static data features.
"""

import polars as pl

from MEDS_tabular_automl.utils import (
DF_T,
_normalize_flat_rep_df_cols,
_parse_flat_feature_column,
)


def _summarize_static_measurements(
feature_columns: list[str],
df: DF_T,
) -> pl.LazyFrame:
"""Aggregates static measurements for feature columns that are marked as 'present' or 'first'.
Parameters:
- feature_columns (list[str]): List of feature column identifiers that are specifically marked
for staticanalysis.
- df (DF_T): Data frame from which features will be extracted and summarized.
Returns:
- pl.LazyFrame: A LazyFrame containing the summarized data pivoted by 'patient_id'
for each static feature.
This function first filters for features that need to be recorded as the first occurrence
or simply as present, then performs a pivot to reshape the data for each patient, providing
a tabular format where each row represents a patient and each column represents a static feature.
"""
static_present = [c for c in feature_columns if c.startswith("STATIC_") and c.endswith("present")]
static_first = [c for c in feature_columns if c.startswith("STATIC_") and c.endswith("first")]

# Handling 'first' static values
static_first_codes = [_parse_flat_feature_column(c)[1] for c in static_first]
code_subset = df.filter(pl.col("code").is_in(static_first_codes))
first_code_subset = code_subset.groupby(pl.col("patient_id")).first().collect()
static_value_pivot_df = first_code_subset.pivot(
index=["patient_id"], columns=["code"], values=["numerical_value"], aggregate_function=None
)
# rename code to feature name
remap_cols = {
input_name: output_name
for input_name, output_name in zip(static_first_codes, static_first)
if input_name in static_value_pivot_df.columns
}
static_value_pivot_df = static_value_pivot_df.select(
*["patient_id"], *[pl.col(k).alias(v).cast(pl.Boolean) for k, v in remap_cols.items()]
)
# pivot can be faster: https://stackoverflow.com/questions/73522017/replacing-a-pivot-with-a-lazy-groupby-operation # noqa: E501
    # TODO: consider casting with .cast(pl.Float32)
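    # A lazy alternative to the pivot above (an untested sketch following the
    # linked pattern, not part of this commit) could aggregate per patient instead:
    #   code_subset.groupby("patient_id").agg(
    #       [pl.col("numerical_value").filter(pl.col("code") == c).first().alias(c)
    #        for c in static_first_codes]
    #   )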

# Handling 'present' static indicators
static_present_codes = [_parse_flat_feature_column(c)[1] for c in static_present]
static_present_pivot_df = (
df.select(*["patient_id", "code"])
.filter(pl.col("code").is_in(static_present_codes))
.with_columns(pl.lit(True).alias("__indicator"))
.collect()
.pivot(
index=["patient_id"],
columns=["code"],
values="__indicator",
aggregate_function=None,
)
)
remap_cols = {
input_name: output_name
for input_name, output_name in zip(static_present_codes, static_present)
if input_name in static_present_pivot_df.columns
}
# rename columns to final feature names
static_present_pivot_df = static_present_pivot_df.select(
*["patient_id"], *[pl.col(k).alias(v).cast(pl.Boolean) for k, v in remap_cols.items()]
)
return pl.concat([static_value_pivot_df, static_present_pivot_df], how="align")


def get_flat_static_rep(
feature_columns: list[str],
shard_df: DF_T,
) -> pl.LazyFrame:
"""Produces a raw representation for static data from a specified shard DataFrame.
Parameters:
- feature_columns (list[str]): List of feature columns to include in the static representation.
- shard_df (DF_T): The shard DataFrame containing patient data.
Returns:
- pl.LazyFrame: A LazyFrame that includes all static features for the data provided.
This function selects the appropriate static features, summarizes them using
_summarize_static_measurements, and then normalizes the resulting data to ensure it is
suitable for further analysis or machine learning tasks.
"""
static_features = [c for c in feature_columns if c.startswith("STATIC_")]
static_measurements = _summarize_static_measurements(static_features, df=shard_df)
# fill up missing feature columns with nulls
normalized_measurements = _normalize_flat_rep_df_cols(
static_measurements,
static_features,
set_count_0_to_null=False,
)
return normalized_measurements
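
Below is a minimal sketch of the subject-count check described in the commit message, reusing the 'present' pivot pattern from _summarize_static_measurements. The toy shard, codes, and values are illustrative assumptions, not the repository's actual test fixtures, and the snippet targets the same pre-1.0 polars API used in this file:

import polars as pl

# Toy MEDS-style shard: one row per (patient, code) observation.
df = pl.DataFrame(
    {
        "patient_id": [1, 1, 2, 3],
        "code": ["EYE_COLOR", "HEIGHT", "EYE_COLOR", "HEIGHT"],
        "numerical_value": [None, 175.0, None, 160.0],
    }
)

# Pivot 'present' indicators: one row per patient, one column per observed code.
present = (
    df.select("patient_id", "code")
    .with_columns(pl.lit(True).alias("__indicator"))
    .pivot(index=["patient_id"], columns=["code"], values="__indicator", aggregate_function=None)
)

# The property the new test checks: the static representation has exactly one row per subject.
assert present.height == df["patient_id"].n_unique()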
230 changes: 230 additions & 0 deletions src/MEDS_tabular_automl/generate_ts_features.py
@@ -0,0 +1,230 @@
"""WIP.
This file will be used to generate time series features from the raw data.
"""
from collections.abc import Callable
from pathlib import Path

import numpy as np
import polars as pl
import polars.selectors as cs

from MEDS_tabular_automl.utils import (
DF_T,
_normalize_flat_rep_df_cols,
_parse_flat_feature_column,
)


def _summarize_dynamic_measurements(
self,
feature_columns: list[str],
include_only_subjects: set[int] | None = None,
) -> pl.LazyFrame:
if include_only_subjects is None:
df = self.dynamic_measurements_df
else:
df = self.dynamic_measurements_df.join(
self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select("event_id"),
on="event_id",
how="inner",
)

valid_measures = {}
for feat_col in feature_columns:
temp, meas, feat = _parse_flat_feature_column(feat_col)

if temp != "dynamic":
continue

if meas not in valid_measures:
valid_measures[meas] = set()
valid_measures[meas].add(feat)
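        # e.g. (hypothetical feature name): "dynamic/HR/value" would parse to
        # temp="dynamic", meas="HR", feat="value", leaving valid_measures == {"HR": {"value"}}.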

out_dfs = {}
for m, allowed_vocab in valid_measures.items():
cfg = self.measurement_configs[m]

total_observations = int(
np.ceil(
cfg.observation_rate_per_case
* cfg.observation_rate_over_cases
* sum(self.n_events_per_subject.values())
)
)

count_type = self.get_smallest_valid_uint_type(total_observations)

if cfg.modality == "univariate_regression" and cfg.vocabulary is None:
prefix = f"dynamic/{m}/{m}"

key_col = pl.col(m)
val_col = pl.col(m).drop_nans().cast(pl.Float32)

out_dfs[m] = (
df.lazy()
.select("measurement_id", "event_id", m)
.filter(pl.col(m).is_not_null())
.groupby("event_id")
.agg(
pl.col(m).is_not_null().sum().cast(count_type).alias(f"{prefix}/count"),
(
(pl.col(m).is_not_nan() & pl.col(m).is_not_null())
.sum()
.cast(count_type)
.alias(f"{prefix}/has_values_count")
),
val_col.sum().alias(f"{prefix}/sum"),
(val_col**2).sum().alias(f"{prefix}/sum_sqd"),
val_col.min().alias(f"{prefix}/min"),
val_col.max().alias(f"{prefix}/max"),
)
)
continue
elif cfg.modality == "multivariate_regression":
column_cols = [m, m]
values_cols = [m, cfg.values_column]
key_prefix = f"{m}_{m}_"
val_prefix = f"{cfg.values_column}_{m}_"

key_col = cs.starts_with(key_prefix)
val_col = cs.starts_with(val_prefix).drop_nans().cast(pl.Float32)

aggs = [
key_col.is_not_null()
.sum()
.cast(count_type)
.map_alias(lambda c: f"dynamic/{m}/{c.replace(key_prefix, '')}/count"),
(
(cs.starts_with(val_prefix).is_not_null() & cs.starts_with(val_prefix).is_not_nan())
.sum()
.map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/has_values_count")
),
val_col.sum().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/sum"),
(val_col**2).sum().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/sum_sqd"),
val_col.min().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/min"),
val_col.max().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/max"),
]
else:
column_cols = [m]
values_cols = [m]
aggs = [
pl.all().is_not_null().sum().cast(count_type).map_alias(lambda c: f"dynamic/{m}/{c}/count")
]

ID_cols = ["measurement_id", "event_id"]
out_dfs[m] = (
df.select(*ID_cols, *set(column_cols + values_cols))
.filter(pl.col(m).is_in(allowed_vocab))
.pivot(
index=ID_cols,
columns=column_cols,
values=values_cols,
aggregate_function=None,
)
.lazy()
.drop("measurement_id")
.groupby("event_id")
.agg(*aggs)
)

return pl.concat(list(out_dfs.values()), how="align")


def _summarize_over_window(df: DF_T, window_size: str) -> pl.LazyFrame:
"""Apply aggregations to the raw representation over a window size."""
if isinstance(df, Path):
df = pl.scan_parquet(df)

def time_aggd_col_alias_fntr(new_agg: str | None = None) -> Callable[[str], str]:
if new_agg is None:

def f(c: str) -> str:
return "/".join([window_size] + c.split("/")[1:])

else:

def f(c: str) -> str:
return "/".join([window_size] + c.split("/")[1:-1] + [new_agg])

return f
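    # For example (hypothetical column names), with window_size == "1d":
    #   time_aggd_col_alias_fntr()("dynamic/HR/value/sum")    -> "1d/HR/value/sum"
    #   time_aggd_col_alias_fntr("count")("dynamic/HR/value") -> "1d/HR/count"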

# Columns to convert to counts:
present_indicator_cols = cs.ends_with("/present")

# Columns to convert to value aggregations:
value_cols = cs.ends_with("/value")

# Columns to aggregate via other operations
cnt_cols = (cs.ends_with("/count") | cs.ends_with("/has_values_count")).fill_null(0)

cols_to_sum = cs.ends_with("/sum") | cs.ends_with("/sum_sqd")
cols_to_min = cs.ends_with("/min")
cols_to_max = cs.ends_with("/max")

if window_size == "FULL":
df = df.groupby("subject_id").agg(
"timestamp",
# present to counts
present_indicator_cols.cumsum().map_alias(time_aggd_col_alias_fntr("count")),
# values to stats
value_cols.is_not_null().cumsum().map_alias(time_aggd_col_alias_fntr("count")),
(
(value_cols.is_not_null() & value_cols.is_not_nan())
.cumsum()
.map_alias(time_aggd_col_alias_fntr("has_values_count"))
),
value_cols.cumsum().map_alias(time_aggd_col_alias_fntr("sum")),
(value_cols**2).cumsum().map_alias(time_aggd_col_alias_fntr("sum_sqd")),
value_cols.cummin().map_alias(time_aggd_col_alias_fntr("min")),
value_cols.cummax().map_alias(time_aggd_col_alias_fntr("max")),
# Raw aggregations
cnt_cols.cumsum().map_alias(time_aggd_col_alias_fntr()),
cols_to_sum.cumsum().map_alias(time_aggd_col_alias_fntr()),
cols_to_min.cummin().map_alias(time_aggd_col_alias_fntr()),
cols_to_max.cummax().map_alias(time_aggd_col_alias_fntr()),
)
df = df.explode(*[c for c in df.columns if c != "subject_id"])
else:
df = df.groupby_rolling(
index_column="timestamp",
by="subject_id",
period=window_size,
).agg(
# present to counts
present_indicator_cols.sum().map_alias(time_aggd_col_alias_fntr("count")),
# values to stats
value_cols.is_not_null().sum().map_alias(time_aggd_col_alias_fntr("count")),
(
(value_cols.is_not_null() & value_cols.is_not_nan())
.sum()
.map_alias(time_aggd_col_alias_fntr("has_values_count"))
),
value_cols.sum().map_alias(time_aggd_col_alias_fntr("sum")),
(value_cols**2).sum().map_alias(time_aggd_col_alias_fntr("sum_sqd")),
value_cols.min().map_alias(time_aggd_col_alias_fntr("min")),
value_cols.max().map_alias(time_aggd_col_alias_fntr("max")),
# Raw aggregations
cnt_cols.sum().map_alias(time_aggd_col_alias_fntr()),
cols_to_sum.sum().map_alias(time_aggd_col_alias_fntr()),
cols_to_min.min().map_alias(time_aggd_col_alias_fntr()),
cols_to_max.max().map_alias(time_aggd_col_alias_fntr()),
)

return _normalize_flat_rep_df_cols(df, set_count_0_to_null=True)


def get_flat_ts_rep(
feature_columns: list[str],
**kwargs,
) -> pl.LazyFrame:
"""Produce raw representation for dynamic data."""

return _normalize_flat_rep_df_cols(
_summarize_dynamic_measurements(feature_columns, **kwargs)
.sort(by=["subject_id", "timestamp"])
.collect()
.lazy(),
[c for c in feature_columns if c.startswith("dynamic")],
)
# The above .collect().lazy() shouldn't be necessary but it appears to be for some reason...
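
A minimal, self-contained sketch of the trailing-window aggregation pattern that _summarize_over_window builds on. The toy data and column names are illustrative assumptions, and the snippet targets the same pre-1.0 polars API as this file (groupby_rolling, before its rename to rolling):

import polars as pl
from datetime import datetime

events = pl.DataFrame(
    {
        "subject_id": [1, 1, 1],
        "timestamp": [datetime(2024, 1, 1), datetime(2024, 1, 2), datetime(2024, 1, 5)],
        "value/sum": [1.0, 2.0, 4.0],
    }
)

# Trailing 2-day windows per subject, one output row per event.
rolled = events.groupby_rolling(
    index_column="timestamp", by="subject_id", period="2d"
).agg(pl.col("value/sum").sum().alias("2d/value/sum"))

# Windows end at each timestamp, so the sums are [1.0, 3.0, 4.0].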