Converting Esgpt caching to work for MEDS datasets #1
Merged
75 commits
d5ff0df (Oufattole): added test case and initial testing and code for sharding output t…
4a486aa (Oufattole): added static feature pivoting
19f0f4e (Oufattole): added docstrings and a simple test case checking the number of subjec…
fd1731f (Oufattole): Refactor scripts into separate modules for improved clarity:
63b9ba6 (Oufattole): fixed doctests and updated github workflow tests to use python 3.12
cd067f8 (Oufattole): Implement data processing for MEDS format to pivot tables into two in…
c8ca3bb (Oufattole): Enhance data aggregation framework with dynamic window and aggregatio…
7fdc37d (mmcdermott): Update src/MEDS_tabular_automl/generate_summarized_reps.py
4dd3cad (mmcdermott): Update src/MEDS_tabular_automl/generate_summarized_reps.py
720a533 (mmcdermott): Update src/MEDS_tabular_automl/generate_summarized_reps.py
548e29a (Oufattole): Added doctest and updated docstrings in identiy_columns.py. [WIP] add…
d9ba7e7 (Oufattole): Merge branch 'esgpt_caching' of github.com:mmcdermott/MEDS_Tabular_Au…
f0b1cbb (teyaberg): working on xgboost
ba954ef (Oufattole): current state
df2750a (mmcdermott): Removed tqdm, fixed deprecated groupbys, fixed doctest long-line issue.
4bbbc20 (mmcdermott): Fixed one of the summary doctests.
d39bf1a (teyaberg): updates based on formats... still many to dos
41fe4b4 (Oufattole): using sparse matrices for generating time series representations
cb5f689 (teyaberg): still working on sparse matrix to external memory xgboost
8bc9a16 (teyaberg): same problem
1e27526 (teyaberg): cleaned some testing
97938a8 (Oufattole): sped up the tabularize_ts script by about 30% by concatenating the sp…
c28e6b2 (teyaberg): got iterator working with csr_matrices for X and numpy arrays for y
6f3b1ec (Oufattole): added support for sparse aggregations
6753609 (Oufattole): passing unit tests for sparse aggregations (only code/count and value…
f125600 (Oufattole): added significant speed improvements for rolling window aggregations
2acc3bc (Oufattole): improved speed, by removing conversion from sparse scipy matrix to sp…
eec05e2 (Oufattole): takes about an hour to run through a shard. The speed gain is from me…
bd9bdae (Oufattole): added scripts to the readme
29c8c5f (teyaberg): save before breaking it
4c7d3e7 (Oufattole): added support for parallelism using mapper warp function. We cache fe…
ba796e5 (teyaberg): wip
3678d30 (teyaberg): automl
82b3903 (Oufattole): Merge branch 'esgpt_caching' into xgboost
ffa0f3c (teyaberg): working on collect_in_memory
c8f26ea (teyaberg): collect in memory fixed
f6a3751: added hf_cohort scripts
2ec1860 (mmcdermott): Apply suggestions from code review
db18dc5 (teyaberg): cleaning
abba3d2 (teyaberg): local WIP--changing to sparse matrix implementation
77f296f: added merging of static and time series data
e8f26eb: Merge branch 'esgpt_caching' of github.com:mmcdermott/MEDS_Tabular_Au…
958906d (Oufattole): merging script runs, but the output is 50GB
7668382 (Oufattole): merging script works and is efficient
b6b8d43 (Oufattole): fixed bug with sparse matrix shape being too small for merging static…
e6a88a7 (teyaberg): changed to sparse format
e8d64fd (Oufattole): added script for extracting tasks using aces
5b2f7f7 (Oufattole): merged xgboost code
5c5dc8e (Oufattole): added dependencies
357845e (Oufattole): Merge branch 'xgboost' into esgpt_caching
d99e274 (Oufattole): added support for loading cached labels and event indexes
cadc603 (Oufattole): updated readme
285ccbf (teyaberg): size issues for loading sparse matrix
795b532 (teyaberg): push updates
b9d057b (Oufattole): 4x speed increase for tabularization to sparse matrix by caching wind…
85f38b5 (Oufattole): Merge branch 'xgboost' into esgpt_caching
7ea3230 (Oufattole): standardized file storage using file_name.py and updated from using n…
23a2e3b (Oufattole): cleaned up file paths so we can load all aggregations selectively and…
c225c47: fixed bug with codes that are only in the test and validation set (no…
cb21821: fixed bug with summarization script crashing for min and max value ag…
3a412a0: removed overwrite killing of jobs which causes errors in multirun
a4f1843 (Oufattole): Xgboost is able to load all concatenated windows and aggregations. Fi…
800ab7e: fixed timedelta overflow bug
820e194: Merge branch 'esgpt_caching' of github.com:mmcdermott/MEDS_Tabular_Au…
4b0637a: fixed bug with loading feature columns json for aces task script
127d04a (Oufattole): added memory profiling to hf_cohort e2e script
23877ad (Oufattole): Merge branch 'esgpt_caching' of github.com:mmcdermott/MEDS_Tabular_Au…
36f54a3 (mmcdermott): Made tests ignore the hf_cohort directory
81bf2d9 (mmcdermott): Pre-commit fixes
83c4eec (mmcdermott): Resolving deprecation warnings
e7a85ba (mmcdermott): Fixed test installation instructions.
35acb97 (mmcdermott): Merge branch 'esgpt_caching' into mmd_changes
bef63b6 (mmcdermott): Resolved one error (or, rather, shifted it) by making some things pro…
e9775e2 (mmcdermott): Shifted more test errors around, but the failures are deeper than exp…
c8f4144 (mmcdermott): Merge pull request #4 from mmcdermott/mmd_changes
@@ -0,0 +1,112 @@
"""This module provides functions for generating static representations of patient data for use in automated | ||
machine learning models. It includes functionality to summarize measurements based on static features and then | ||
transform them into a tabular format suitable for analysis. The module leverages the polars library for | ||
efficient data manipulation. | ||
|
||
Functions: | ||
- _summarize_static_measurements: Summarizes static measurements from a given DataFrame. | ||
- get_flat_static_rep: Produces a tabular representation of static data features. | ||
""" | ||
|
||
import polars as pl | ||
|
||
from MEDS_tabular_automl.utils import ( | ||
DF_T, | ||
_normalize_flat_rep_df_cols, | ||
_parse_flat_feature_column, | ||
) | ||
|
||
|
||
def _summarize_static_measurements( | ||
feature_columns: list[str], | ||
df: DF_T, | ||
) -> pl.LazyFrame: | ||
"""Aggregates static measurements for feature columns that are marked as 'present' or 'first'. | ||
|
||
Parameters: | ||
- feature_columns (list[str]): List of feature column identifiers that are specifically marked | ||
for staticanalysis. | ||
- df (DF_T): Data frame from which features will be extracted and summarized. | ||
|
||
Returns: | ||
- pl.LazyFrame: A LazyFrame containing the summarized data pivoted by 'patient_id' | ||
for each static feature. | ||
|
||
This function first filters for features that need to be recorded as the first occurrence | ||
or simply as present, then performs a pivot to reshape the data for each patient, providing | ||
a tabular format where each row represents a patient and each column represents a static feature. | ||
""" | ||
static_present = [c for c in feature_columns if c.startswith("STATIC_") and c.endswith("present")] | ||
static_first = [c for c in feature_columns if c.startswith("STATIC_") and c.endswith("first")] | ||
|
||
# Handling 'first' static values | ||
static_first_codes = [_parse_flat_feature_column(c)[1] for c in static_first] | ||
code_subset = df.filter(pl.col("code").is_in(static_first_codes)) | ||
first_code_subset = code_subset.groupby(pl.col("patient_id")).first().collect() | ||
static_value_pivot_df = first_code_subset.pivot( | ||
index=["patient_id"], columns=["code"], values=["numerical_value"], aggregate_function=None | ||
) | ||
# rename code to feature name | ||
remap_cols = { | ||
input_name: output_name | ||
for input_name, output_name in zip(static_first_codes, static_first) | ||
if input_name in static_value_pivot_df.columns | ||
} | ||
static_value_pivot_df = static_value_pivot_df.select( | ||
*["patient_id"], *[pl.col(k).alias(v).cast(pl.Boolean) for k, v in remap_cols.items()] | ||
) | ||
# pivot can be faster: https://stackoverflow.com/questions/73522017/replacing-a-pivot-with-a-lazy-groupby-operation # noqa: E501 | ||
# TODO: consider casting with .cast(pl.Float32)) | ||
|
||
# Handling 'present' static indicators | ||
static_present_codes = [_parse_flat_feature_column(c)[1] for c in static_present] | ||
static_present_pivot_df = ( | ||
df.select(*["patient_id", "code"]) | ||
.filter(pl.col("code").is_in(static_present_codes)) | ||
.with_columns(pl.lit(True).alias("__indicator")) | ||
.collect() | ||
.pivot( | ||
index=["patient_id"], | ||
columns=["code"], | ||
values="__indicator", | ||
aggregate_function=None, | ||
) | ||
) | ||
remap_cols = { | ||
input_name: output_name | ||
for input_name, output_name in zip(static_present_codes, static_present) | ||
if input_name in static_present_pivot_df.columns | ||
} | ||
# rename columns to final feature names | ||
static_present_pivot_df = static_present_pivot_df.select( | ||
*["patient_id"], *[pl.col(k).alias(v).cast(pl.Boolean) for k, v in remap_cols.items()] | ||
) | ||
return pl.concat([static_value_pivot_df, static_present_pivot_df], how="align") | ||
|
||
|
||
def get_flat_static_rep( | ||
feature_columns: list[str], | ||
shard_df: DF_T, | ||
) -> pl.LazyFrame: | ||
"""Produces a raw representation for static data from a specified shard DataFrame. | ||
|
||
Parameters: | ||
- feature_columns (list[str]): List of feature columns to include in the static representation. | ||
- shard_df (DF_T): The shard DataFrame containing patient data. | ||
|
||
Returns: | ||
- pl.LazyFrame: A LazyFrame that includes all static features for the data provided. | ||
|
||
This function selects the appropriate static features, summarizes them using | ||
_summarize_static_measurements, and then normalizes the resulting data to ensure it is | ||
suitable for further analysis or machine learning tasks. | ||
""" | ||
static_features = [c for c in feature_columns if c.startswith("STATIC_")] | ||
static_measurements = _summarize_static_measurements(static_features, df=shard_df) | ||
# fill up missing feature columns with nulls | ||
normalized_measurements = _normalize_flat_rep_df_cols( | ||
static_measurements, | ||
static_features, | ||
set_count_0_to_null=False, | ||
) | ||
return normalized_measurements |
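The pivots above are eager (note the .collect() calls before .pivot()), which the in-code comment flags as a potential bottleneck. Below is a minimal sketch of the lazy groupby alternative the linked StackOverflow answer describes, on toy data: the patient IDs and codes are hypothetical, and it assumes the same polars generation as this diff (where LazyFrame.groupby exists; newer polars renames it group_by).

import polars as pl

lf = pl.LazyFrame(
    {
        "patient_id": [1, 1, 2],
        "code": ["EYE_COLOR", "GENDER", "GENDER"],  # hypothetical codes
    }
)
codes = ["EYE_COLOR", "GENDER"]

# One boolean indicator column per code, built entirely lazily, in place of
# the eager .collect().pivot(...) round-trip:
present_lf = lf.groupby("patient_id").agg(
    [(pl.col("code") == c).any().alias(c) for c in codes]
)
print(present_lf.collect())  # one row per patient_id, one bool column per code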
@@ -0,0 +1,230 @@
"""WIP. | ||
|
||
This file will be used to generate time series features from the raw data. | ||
""" | ||
from collections.abc import Callable | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import polars as pl | ||
import polars.selectors as cs | ||
|
||
from MEDS_tabular_automl.utils import ( | ||
DF_T, | ||
_normalize_flat_rep_df_cols, | ||
_parse_flat_feature_column, | ||
) | ||
|
||
|
||
def _summarize_dynamic_measurements( | ||
self, | ||
feature_columns: list[str], | ||
include_only_subjects: set[int] | None = None, | ||
) -> pl.LazyFrame: | ||
if include_only_subjects is None: | ||
df = self.dynamic_measurements_df | ||
else: | ||
df = self.dynamic_measurements_df.join( | ||
self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select("event_id"), | ||
on="event_id", | ||
how="inner", | ||
) | ||
|
||
valid_measures = {} | ||
for feat_col in feature_columns: | ||
temp, meas, feat = _parse_flat_feature_column(feat_col) | ||
|
||
if temp != "dynamic": | ||
continue | ||
|
||
if meas not in valid_measures: | ||
valid_measures[meas] = set() | ||
valid_measures[meas].add(feat) | ||
|
||
out_dfs = {} | ||
for m, allowed_vocab in valid_measures.items(): | ||
cfg = self.measurement_configs[m] | ||
|
||
total_observations = int( | ||
np.ceil( | ||
cfg.observation_rate_per_case | ||
* cfg.observation_rate_over_cases | ||
* sum(self.n_events_per_subject.values()) | ||
) | ||
) | ||
|
||
count_type = self.get_smallest_valid_uint_type(total_observations) | ||
|
||
if cfg.modality == "univariate_regression" and cfg.vocabulary is None: | ||
prefix = f"dynamic/{m}/{m}" | ||
|
||
key_col = pl.col(m) | ||
val_col = pl.col(m).drop_nans().cast(pl.Float32) | ||
|
||
out_dfs[m] = ( | ||
df.lazy() | ||
.select("measurement_id", "event_id", m) | ||
.filter(pl.col(m).is_not_null()) | ||
.groupby("event_id") | ||
.agg( | ||
pl.col(m).is_not_null().sum().cast(count_type).alias(f"{prefix}/count"), | ||
( | ||
(pl.col(m).is_not_nan() & pl.col(m).is_not_null()) | ||
.sum() | ||
.cast(count_type) | ||
.alias(f"{prefix}/has_values_count") | ||
), | ||
val_col.sum().alias(f"{prefix}/sum"), | ||
(val_col**2).sum().alias(f"{prefix}/sum_sqd"), | ||
val_col.min().alias(f"{prefix}/min"), | ||
val_col.max().alias(f"{prefix}/max"), | ||
) | ||
) | ||
continue | ||
elif cfg.modality == "multivariate_regression": | ||
column_cols = [m, m] | ||
values_cols = [m, cfg.values_column] | ||
key_prefix = f"{m}_{m}_" | ||
val_prefix = f"{cfg.values_column}_{m}_" | ||
|
||
key_col = cs.starts_with(key_prefix) | ||
val_col = cs.starts_with(val_prefix).drop_nans().cast(pl.Float32) | ||
|
||
aggs = [ | ||
key_col.is_not_null() | ||
.sum() | ||
.cast(count_type) | ||
.map_alias(lambda c: f"dynamic/{m}/{c.replace(key_prefix, '')}/count"), | ||
( | ||
(cs.starts_with(val_prefix).is_not_null() & cs.starts_with(val_prefix).is_not_nan()) | ||
.sum() | ||
.map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/has_values_count") | ||
), | ||
val_col.sum().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/sum"), | ||
(val_col**2).sum().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/sum_sqd"), | ||
val_col.min().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/min"), | ||
val_col.max().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/max"), | ||
] | ||
else: | ||
column_cols = [m] | ||
values_cols = [m] | ||
aggs = [ | ||
pl.all().is_not_null().sum().cast(count_type).map_alias(lambda c: f"dynamic/{m}/{c}/count") | ||
] | ||
|
||
ID_cols = ["measurement_id", "event_id"] | ||
out_dfs[m] = ( | ||
df.select(*ID_cols, *set(column_cols + values_cols)) | ||
.filter(pl.col(m).is_in(allowed_vocab)) | ||
.pivot( | ||
index=ID_cols, | ||
columns=column_cols, | ||
values=values_cols, | ||
aggregate_function=None, | ||
) | ||
.lazy() | ||
.drop("measurement_id") | ||
.groupby("event_id") | ||
.agg(*aggs) | ||
) | ||
|
||
return pl.concat(list(out_dfs.values()), how="align") | ||
|
||
|
||
def _summarize_over_window(df: DF_T, window_size: str) -> pl.LazyFrame: | ||
"""Apply aggregations to the raw representation over a window size.""" | ||
if isinstance(df, Path): | ||
df = pl.scan_parquet(df) | ||
|
||
def time_aggd_col_alias_fntr(new_agg: str | None = None) -> Callable[[str], str]: | ||
if new_agg is None: | ||
|
||
def f(c: str) -> str: | ||
return "/".join([window_size] + c.split("/")[1:]) | ||
|
||
else: | ||
|
||
def f(c: str) -> str: | ||
return "/".join([window_size] + c.split("/")[1:-1] + [new_agg]) | ||
|
||
return f | ||
|
||
# Columns to convert to counts: | ||
present_indicator_cols = cs.ends_with("/present") | ||
|
||
# Columns to convert to value aggregations: | ||
value_cols = cs.ends_with("/value") | ||
|
||
# Columns to aggregate via other operations | ||
cnt_cols = (cs.ends_with("/count") | cs.ends_with("/has_values_count")).fill_null(0) | ||
|
||
cols_to_sum = cs.ends_with("/sum") | cs.ends_with("/sum_sqd") | ||
cols_to_min = cs.ends_with("/min") | ||
cols_to_max = cs.ends_with("/max") | ||
|
||
if window_size == "FULL": | ||
df = df.groupby("subject_id").agg( | ||
"timestamp", | ||
# present to counts | ||
present_indicator_cols.cumsum().map_alias(time_aggd_col_alias_fntr("count")), | ||
# values to stats | ||
value_cols.is_not_null().cumsum().map_alias(time_aggd_col_alias_fntr("count")), | ||
( | ||
(value_cols.is_not_null() & value_cols.is_not_nan()) | ||
.cumsum() | ||
.map_alias(time_aggd_col_alias_fntr("has_values_count")) | ||
), | ||
value_cols.cumsum().map_alias(time_aggd_col_alias_fntr("sum")), | ||
(value_cols**2).cumsum().map_alias(time_aggd_col_alias_fntr("sum_sqd")), | ||
value_cols.cummin().map_alias(time_aggd_col_alias_fntr("min")), | ||
value_cols.cummax().map_alias(time_aggd_col_alias_fntr("max")), | ||
# Raw aggregations | ||
cnt_cols.cumsum().map_alias(time_aggd_col_alias_fntr()), | ||
cols_to_sum.cumsum().map_alias(time_aggd_col_alias_fntr()), | ||
cols_to_min.cummin().map_alias(time_aggd_col_alias_fntr()), | ||
cols_to_max.cummax().map_alias(time_aggd_col_alias_fntr()), | ||
) | ||
df = df.explode(*[c for c in df.columns if c != "subject_id"]) | ||
else: | ||
df = df.groupby_rolling( | ||
index_column="timestamp", | ||
by="subject_id", | ||
period=window_size, | ||
).agg( | ||
# present to counts | ||
present_indicator_cols.sum().map_alias(time_aggd_col_alias_fntr("count")), | ||
# values to stats | ||
value_cols.is_not_null().sum().map_alias(time_aggd_col_alias_fntr("count")), | ||
( | ||
(value_cols.is_not_null() & value_cols.is_not_nan()) | ||
.sum() | ||
.map_alias(time_aggd_col_alias_fntr("has_values_count")) | ||
), | ||
value_cols.sum().map_alias(time_aggd_col_alias_fntr("sum")), | ||
(value_cols**2).sum().map_alias(time_aggd_col_alias_fntr("sum_sqd")), | ||
value_cols.min().map_alias(time_aggd_col_alias_fntr("min")), | ||
value_cols.max().map_alias(time_aggd_col_alias_fntr("max")), | ||
# Raw aggregations | ||
cnt_cols.sum().map_alias(time_aggd_col_alias_fntr()), | ||
cols_to_sum.sum().map_alias(time_aggd_col_alias_fntr()), | ||
cols_to_min.min().map_alias(time_aggd_col_alias_fntr()), | ||
cols_to_max.max().map_alias(time_aggd_col_alias_fntr()), | ||
) | ||
|
||
return _normalize_flat_rep_df_cols(df, set_count_0_to_null=True) | ||
|
||
|
||
def get_flat_ts_rep( | ||
feature_columns: list[str], | ||
**kwargs, | ||
) -> pl.LazyFrame: | ||
"""Produce raw representation for dynamic data.""" | ||
|
||
return _normalize_flat_rep_df_cols( | ||
_summarize_dynamic_measurements(feature_columns, **kwargs) | ||
.sort(by=["subject_id", "timestamp"]) | ||
.collect() | ||
.lazy(), | ||
[c for c in feature_columns if c.startswith("dynamic")], | ||
) | ||
# The above .collect().lazy() shouldn't be necessary but it appears to be for some reason... |
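For context on the windowing in _summarize_over_window, here is a small self-contained sketch of polars' rolling groupby on toy data. The column names are illustrative only, and it uses the same groupby_rolling API as this diff (newer polars renames it rolling):

from datetime import datetime

import polars as pl

df = pl.DataFrame(
    {
        "subject_id": [1, 1, 1, 2],
        "timestamp": [
            datetime(2020, 1, 1),
            datetime(2020, 1, 2),
            datetime(2020, 1, 5),
            datetime(2020, 1, 1),
        ],
        "hr/value": [60.0, 70.0, 80.0, 55.0],  # hypothetical feature column
    }
)

# Each row is the right endpoint of a trailing 2-day window within its
# subject; aggregates run over all of that subject's rows in the window.
out = df.groupby_rolling(index_column="timestamp", by="subject_id", period="2d").agg(
    pl.col("hr/value").sum().alias("2d/hr/value/sum"),
    pl.col("hr/value").max().alias("2d/hr/value/max"),
)
print(out)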
Review comment:
Anything with a null timestamp should be assumed to be a static code. They won't all start with "STATIC_" (unless you've done some pre-processing here that I'm missing).
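A sketch of the reviewer's suggestion, assuming a MEDS-style frame where static rows carry a null timestamp (the frame and column names here are hypothetical):

import polars as pl

df = pl.LazyFrame(
    {
        "patient_id": [1, 1, 1],
        "timestamp": [None, "2020-01-01", "2020-01-02"],
        "code": ["GENDER", "HR", "HR"],
    }
)
# Partition on timestamp nullity instead of relying on a "STATIC_" prefix:
static_rows = df.filter(pl.col("timestamp").is_null())       # static codes
dynamic_rows = df.filter(pl.col("timestamp").is_not_null())  # time-stamped events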