
Dev #78

Merged
merged 21 commits on Aug 12, 2024
21 commits
bb79bf8
Updated configs to point to MEDS v0.3 proper directories.
mmcdermott Aug 10, 2024
cebd9ec
Updating output path organization.
mmcdermott Aug 10, 2024
d6a4876
Updated paths to point to MEDS compliant locations. Also fixed some b…
mmcdermott Aug 10, 2024
477e940
Added version dependency.
mmcdermott Aug 10, 2024
ecef0e2
updating data schema to be MEDS v0.3 compatible
teyaberg Aug 10, 2024
a6c4a17
Merge pull request #62 from mmcdermott/57_update_paths_for_MEDS
mmcdermott Aug 10, 2024
eae8b11
Merge pull request #64 from mmcdermott/main
mmcdermott Aug 10, 2024
490308a
Merge branch 'dev' into 55_MEDS_v03
mmcdermott Aug 10, 2024
8b7859c
Merge branch '55_MEDS_v03' into 59_add_meds_dependency
mmcdermott Aug 10, 2024
85b49c3
Merge pull request #63 from mmcdermott/59_add_meds_dependency
mmcdermott Aug 10, 2024
c9595f2
fixing doctest
teyaberg Aug 10, 2024
3314937
Added some starter code for these changes; more will be needed.
mmcdermott Aug 10, 2024
cf2a4e8
Made a bunch of changes mostly for #66 but tests are currently failing.
mmcdermott Aug 11, 2024
9f24efe
Fixed test error with label re-processing.
mmcdermott Aug 11, 2024
867b5c5
Merge pull request #71 from mmcdermott/58_add_reshard_stage
mmcdermott Aug 11, 2024
b01b1b2
Added MEDS-Transform version dependency.
mmcdermott Aug 11, 2024
df25d18
Basic documentation updates.
mmcdermott Aug 11, 2024
afa70e9
Fixing a lint issue
mmcdermott Aug 11, 2024
2410e09
lint issue
teyaberg Aug 11, 2024
5eb32ae
Change from bash to console.
mmcdermott Aug 11, 2024
884a4af
Merge pull request #65 from mmcdermott/55_MEDS_v03
Oufattole Aug 12, 2024
31 changes: 27 additions & 4 deletions README.md
@@ -27,7 +27,8 @@ This repository consists of two key pieces:

## Quick Start

To use MEDS-Tab, install the dependencies following commands below:
To use MEDS-Tab, install the dependencies following the commands below. Note that this version of MEDS-Tab is
compatible with [MEDS v0.3](https://github.com/Medical-Event-Data-Standard/meds/releases/tag/0.3.0).

**Pip Install**

@@ -44,10 +45,10 @@ pip install .

## Scripts and Examples

For an end-to-end example over MIMIC-IV, see the [MIMIC-IV companion repository](https://github.com/mmcdermott/MEDS_TAB_MIMIC_IV).
For an end-to-end example over Philips eICU, see the [eICU companion repository](https://github.com/mmcdermott/MEDS_TAB_EICU).
For an end-to-end example, including re-sharding the input via MEDS-Transforms, see
[this example script](https://gist.github.com/mmcdermott/34194e484d7b2a2f68967b9bbccfb35b).

See [`/tests/test_integration.py`](https://github.com/mmcdermott/MEDS_Tabular_AutoML/blob/main/tests/test_integration.py) for a local example of the end-to-end pipeline being run on synthetic data. This script is a functional test that is also run with `pytest` to verify the correctness of the algorithm.
See [`/tests/test_integration.py`](https://github.com/mmcdermott/MEDS_Tabular_AutoML/blob/main/tests/test_integration.py) for a local example of the end-to-end pipeline (minus re-sharding) being run on synthetic data. This script is a functional test that is also run with `pytest` to verify the correctness of the algorithm.

## Why MEDS-Tab?

@@ -73,6 +74,28 @@ By following these steps, you can seamlessly transform your dataset, define nece

## Core CLI Scripts Overview

0. First, if your data is not already sharded to the degree you want and in a manner that subdivides your
splits with the format `"$SPLIT_NAME/\d+.parquet"`, where `$SPLIT_NAME` does not contain slashes, you will
need to re-shard your data. This can be done via the
[MEDS-Transforms](https://github.com/mmcdermott/MEDS_transforms) library, which is not included in this
repository. Having data sharded by split _is a necessary step_ to ensure that the data is efficiently
processed in parallel. You can re-shard your input MEDS cohort with the following command, run in the
environment into which this package is installed:

```console
# Re-shard pipeline
# $MIMICIV_MEDS_DIR is the directory containing the input, MEDS v0.3 formatted MIMIC-IV data
# $MEDS_TAB_COHORT_DIR is the directory where the re-sharded MEDS dataset will be stored, and where your model
# will store cached files during processing by default.
# $N_PATIENTS_PER_SHARD is the number of patients per shard you want to use.
MEDS_transform-reshard_to_split \
input_dir="$MIMICIV_MEDS_DIR" \
cohort_dir="$MEDS_TAB_COHORT_DIR" \
'stages=["reshard_to_split"]' \
stage="reshard_to_split" \
stage_configs.reshard_to_split.n_patients_per_shard=$N_PATIENTS_PER_SHARD
```

1. **`meds-tab-describe`**: This command processes MEDS data shards to compute the frequencies of different code types. It differentiates codes into the following categories:

- time-series codes (codes with timestamps)
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -14,7 +14,11 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = ["polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost", "scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins"]
dependencies = [
"polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost",
"scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins", "meds==0.3",
"MEDS-transforms==0.0.4",
]

[project.scripts]
meds-tab-describe = "MEDS_tabular_automl.scripts.describe_codes:main"
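As a quick sanity check on the new pins above (`meds==0.3`, `MEDS-transforms==0.0.4`), the installed versions can be read back at runtime. This is a minimal sketch, not part of the PR, using the standard-library `importlib.metadata`:

```python
# Sketch (not part of this PR): confirm the pinned MEDS dependencies resolved after `pip install .`.
from importlib.metadata import version

print(version("meds"))             # expected: 0.3, per the pin in pyproject.toml
print(version("MEDS-transforms"))  # expected: 0.0.4, per the pin in pyproject.toml
```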
4 changes: 3 additions & 1 deletion src/MEDS_tabular_automl/configs/default.yaml
@@ -1,11 +1,13 @@
MEDS_cohort_dir: ???
output_cohort_dir: ???
do_overwrite: False
seed: 1
tqdm: False
worker: 0
loguru_init: False

log_dir: ${output_dir}/.logs/
log_dir: ${output_cohort_dir}/.logs/
cache_dir: ${output_cohort_dir}/.cache

hydra:
verbose: False
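The change above re-roots `log_dir` and the new `cache_dir` on `output_cohort_dir`. A minimal sketch of how these OmegaConf interpolations resolve (`omegaconf` ships with `hydra-core`; the directory value here is a made-up placeholder, not from the PR):

```python
# Sketch of the interpolations in default.yaml; the cohort directory is a placeholder.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "output_cohort_dir": "/data/meds_tab_cohort",
        "log_dir": "${output_cohort_dir}/.logs/",
        "cache_dir": "${output_cohort_dir}/.cache",
    }
)
print(cfg.log_dir)    # /data/meds_tab_cohort/.logs/
print(cfg.cache_dir)  # /data/meds_tab_cohort/.cache
```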
9 changes: 2 additions & 7 deletions src/MEDS_tabular_automl/configs/describe_codes.yaml
@@ -2,13 +2,8 @@ defaults:
- default
- _self_

# split we wish to get metadata for
split: train
# Raw data, must have a subdirectory "train" with the training data split
input_dir: ${MEDS_cohort_dir}/final_cohort/${split}
input_dir: ${output_cohort_dir}/data
# Where to store output code frequency data
cache_dir: ${MEDS_cohort_dir}/.cache
output_dir: ${MEDS_cohort_dir}
output_filepath: ${output_dir}/code_metadata.parquet
output_filepath: ${output_cohort_dir}/metadata/codes.parquet

name: describe_codes
17 changes: 6 additions & 11 deletions src/MEDS_tabular_automl/configs/launch_xgboost.yaml
@@ -9,13 +9,12 @@ defaults:
task_name: task

# Task cached data dir
input_dir: ${MEDS_cohort_dir}/${task_name}/task_cache
input_dir: ${output_cohort_dir}/${task_name}/task_cache
# Directory with task labels
input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort
input_label_dir: ${output_cohort_dir}/${task_name}/labels/
# Where to output the model and cached data
output_dir: ${MEDS_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
output_filepath: ${output_dir}/model_metadata.parquet
cache_dir: ${MEDS_cohort_dir}/.cache
model_dir: ${output_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
output_filepath: ${model_dir}/model_metadata.json

# Model parameters
model_params:
@@ -31,13 +30,9 @@ model_params:
keep_data_in_memory: True
binarize_task: True

hydra:
verbose: False
sweep:
dir: ${output_dir}/.logs/
run:
dir: ${output_dir}/.logs/
log_dir: ${model_dir}/.logs/

hydra:
# Optuna Sweeper
sweeper:
sampler:
6 changes: 3 additions & 3 deletions src/MEDS_tabular_automl/configs/tabularization.yaml
@@ -5,8 +5,8 @@ defaults:

# Raw data
# Where the code metadata is stored
input_code_metadata_fp: ${MEDS_cohort_dir}/code_metadata.parquet
input_dir: ${MEDS_cohort_dir}/final_cohort
output_dir: ${MEDS_cohort_dir}/tabularize
input_code_metadata_fp: ${output_cohort_dir}/metadata/codes.parquet
input_dir: ${output_cohort_dir}/data
output_dir: ${output_cohort_dir}/tabularize

name: tabularization
@@ -1,7 +1,7 @@
# User inputs
allowed_codes: null
min_code_inclusion_frequency: 10
filtered_code_metadata_fp: ${MEDS_cohort_dir}/tabularized_code_metadata.parquet
filtered_code_metadata_fp: ${output_cohort_dir}/tabularized_code_metadata.parquet
window_sizes:
- "1d"
- "7d"
9 changes: 6 additions & 3 deletions src/MEDS_tabular_automl/configs/task_specific_caching.yaml
@@ -5,10 +5,13 @@ defaults:
task_name: task

# Tabularized Data
input_dir: ${MEDS_cohort_dir}/tabularize
input_dir: ${output_cohort_dir}/tabularize
# Where the labels are stored, with columns patient_id, timestamp, label
input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort
input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels
# Where to output the task specific tabularized data
output_dir: ${MEDS_cohort_dir}/${task_name}/task_cache
output_dir: ${output_cohort_dir}/${task_name}/task_cache
output_label_dir: ${output_cohort_dir}/${task_name}/labels

label_column: "boolean_value"

name: task_specific_caching
52 changes: 26 additions & 26 deletions src/MEDS_tabular_automl/describe_codes.py
@@ -82,7 +82,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
>>> data = pl.DataFrame({
... 'patient_id': [1, 1, 2, 2, 3, 3, 3],
... 'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
... 'timestamp': [
... 'time': [
... None,
... datetime(2021, 1, 1),
... None,
@@ -91,7 +91,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
... datetime(2021, 1, 4),
... None
... ],
... 'numerical_value': [1, None, 2, 2, None, None, 3]
... 'numeric_value': [1, None, 2, 2, None, None, 3]
... }).lazy()
>>> assert (
... convert_to_freq_dict(compute_feature_frequencies(data).lazy()) == {
@@ -101,29 +101,29 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
... )
"""
static_df = shard_df.filter(
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null()
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_null()
)
static_code_freqs_df = static_df.group_by("code").agg(pl.count("code").alias("count")).collect()
static_code_freqs = {
row["code"] + "/static/present": row["count"] for row in static_code_freqs_df.iter_rows(named=True)
}

static_value_df = static_df.filter(pl.col("numerical_value").is_not_null())
static_value_df = static_df.filter(pl.col("numeric_value").is_not_null())
static_value_freqs_df = (
static_value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect()
static_value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect()
)
static_value_freqs = {
row["code"] + "/static/first": row["count"] for row in static_value_freqs_df.iter_rows(named=True)
}

ts_df = shard_df.filter(
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_not_null()
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_not_null()
)
code_freqs_df = ts_df.group_by("code").agg(pl.count("code").alias("count")).collect()
code_freqs = {row["code"] + "/code": row["count"] for row in code_freqs_df.iter_rows(named=True)}

value_df = ts_df.filter(pl.col("numerical_value").is_not_null())
value_freqs_df = value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect()
value_df = ts_df.filter(pl.col("numeric_value").is_not_null())
value_freqs_df = value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect()
value_freqs = {row["code"] + "/value": row["count"] for row in value_freqs_df.iter_rows(named=True)}

combined_freqs = {**static_code_freqs, **static_value_freqs, **code_freqs, **value_freqs}
@@ -222,23 +222,23 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
>>> fp = NamedTemporaryFile()
>>> pl.DataFrame({
... "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numerical_value": [1, None, 2, 2, None, 5, None, 3]
... "time": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numeric_value": [1, None, 2, 2, None, 5, None, 3]
... }).write_parquet(fp.name)
>>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
shape: (6, 3)
┌──────┬────────────┬─────────────────┐
│ code ┆ timestamp  ┆ numerical_value │
│ ---  ┆ ---        ┆ ---             │
│ str  ┆ str        ┆ i64             │
╞══════╪════════════╪═════════════════╡
│ A    ┆ 2021-01-01 ┆ null            │
│ A    ┆ 2021-01-01 ┆ null            │
│ D    ┆ null       ┆ null            │
│ D    ┆ null       ┆ null            │
│ E    ┆ 2021-01-03 ┆ null            │
│ E    ┆ 2021-01-04 ┆ 3               │
└──────┴────────────┴─────────────────┘
┌──────┬────────────┬───────────────┐
│ code ┆ time       ┆ numeric_value │
│ ---  ┆ ---        ┆ ---           │
│ str  ┆ str        ┆ i64           │
╞══════╪════════════╪═══════════════╡
│ A    ┆ 2021-01-01 ┆ null          │
│ A    ┆ 2021-01-01 ┆ null          │
│ D    ┆ null       ┆ null          │
│ D    ┆ null       ┆ null          │
│ E    ┆ 2021-01-03 ┆ null          │
│ E    ┆ 2021-01-04 ┆ 3             │
└──────┴────────────┴───────────────┘
>>> fp.close()
"""
df = pl.scan_parquet(fp)
@@ -257,8 +257,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
clear_code_aggregation_suffix(each) for each in get_feature_names("value/sum", allowed_codes)
]

is_static_code = pl.col("timestamp").is_null()
is_numeric_code = pl.col("numerical_value").is_not_null()
is_static_code = pl.col("time").is_null()
is_numeric_code = pl.col("numeric_value").is_not_null()
rare_static_code = is_static_code & ~pl.col("code").is_in(static_present_feature_columns)
rare_ts_code = ~is_static_code & ~pl.col("code").is_in(code_feature_columns)
rare_ts_value = ~is_static_code & ~pl.col("code").is_in(value_feature_columns) & is_numeric_code
@@ -268,8 +268,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
df = df.with_columns(
pl.when(rare_static_value | rare_ts_value)
.then(None)
.otherwise(pl.col("numerical_value"))
.alias("numerical_value")
.otherwise(pl.col("numeric_value"))
.alias("numeric_value")
)
# Drop rows with rare codes
df = df.filter(~(rare_static_code | rare_ts_code))
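The `describe_codes.py` changes above are largely a rename from the old `timestamp`/`numerical_value` columns to the MEDS v0.3 names `time`/`numeric_value`. The following is a small, self-contained sketch (not code from the PR; it assumes a recent polars) of the static vs. time-series split that `compute_feature_frequencies` builds on:

```python
# Sketch of splitting static rows (null "time") from time-series rows and counting codes,
# using the MEDS v0.3 column names. The example data is made up.
from datetime import datetime

import polars as pl

shard_df = pl.DataFrame(
    {
        "patient_id": [1, 1, 2, 2],
        "code": ["EYE_COLOR//BLUE", "HR", "EYE_COLOR//BROWN", "HR"],
        "time": [None, datetime(2021, 1, 1), None, datetime(2021, 1, 2)],
        "numeric_value": [None, 88.0, None, 72.0],
    }
).lazy()

static_df = shard_df.filter(pl.col("time").is_null())  # static codes: no event time
ts_df = shard_df.filter(pl.col("time").is_not_null())  # time-series codes

static_counts = static_df.group_by("code").agg(pl.col("code").count().alias("count")).collect()
value_counts = (
    ts_df.filter(pl.col("numeric_value").is_not_null())
    .group_by("code")
    .agg(pl.col("code").count().alias("count"))
    .collect()
)
print(static_counts)
print(value_counts)
```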
2 changes: 1 addition & 1 deletion src/MEDS_tabular_automl/generate_static_features.py
@@ -119,7 +119,7 @@ def summarize_static_measurements(
code_subset = df.filter(pl.col("code").is_in(static_first_codes))
first_code_subset = code_subset.group_by(pl.col("patient_id")).first().collect()
static_value_pivot_df = first_code_subset.pivot(
index=["patient_id"], columns=["code"], values=["numerical_value"], aggregate_function=None
index=["patient_id"], columns=["code"], values=["numeric_value"], aggregate_function=None
)
# rename code to feature name
remap_cols = {
10 changes: 5 additions & 5 deletions src/MEDS_tabular_automl/generate_summarized_reps.py
@@ -59,7 +59,7 @@ def get_rolling_window_indicies(index_df: pl.LazyFrame, window_size: str) -> pl.
timedelta = pd.Timedelta(window_size)
return (
index_df.with_row_index("index")
.rolling(index_column="timestamp", period=timedelta, group_by="patient_id")
.rolling(index_column="time", period=timedelta, group_by="patient_id")
.agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
.select(pl.col("min_index", "max_index"))
.collect()
@@ -133,11 +133,11 @@ def compute_agg(
"""Applies aggregation to a sparse matrix using rolling window indices derived from a DataFrame.

Dataframe is expected to only have the relevant columns for aggregating. It should have the patient_id and
timestamp columns, and then only code columns if agg is a code aggregation or only value columns if it is
time columns, and then only code columns if agg is a code aggregation or only value columns if it is
a value aggregation.

Args:
index_df: The DataFrame with 'patient_id' and 'timestamp' columns used for grouping.
index_df: The DataFrame with 'patient_id' and 'time' columns used for grouping.
matrix: The sparse matrix to be aggregated.
window_size: The string defining the rolling window size.
agg: The string specifying the aggregation method.
@@ -149,11 +149,11 @@
"""
group_df = (
index_df.with_row_index("index")
.group_by(["patient_id", "timestamp"], maintain_order=True)
.group_by(["patient_id", "time"], maintain_order=True)
.agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
.collect()
)
index_df = group_df.lazy().select(pl.col("patient_id", "timestamp"))
index_df = group_df.lazy().select(pl.col("patient_id", "time"))
windows = group_df.select(pl.col("min_index", "max_index"))
logger.info("Step 1.5: Running sparse aggregation.")
matrix = aggregate_matrix(windows, matrix, agg, num_features, use_tqdm)
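For orientation, the rolling-window grouping used above (now keyed on `time`) can be reproduced in isolation. A minimal sketch, assuming a recent polars that provides `LazyFrame.rolling` and `with_row_index`; this is illustrative, not code from the PR:

```python
# Sketch of per-patient rolling windows over "time"; the 7-day window and data are illustrative.
from datetime import datetime

import polars as pl

index_df = (
    pl.DataFrame(
        {
            "patient_id": [1, 1, 1, 2],
            "time": [
                datetime(2021, 1, 1),
                datetime(2021, 1, 5),
                datetime(2021, 1, 20),
                datetime(2021, 2, 1),
            ],
        }
    )
    .lazy()
    .sort("patient_id", "time")
)

windows = (
    index_df.with_row_index("index")
    .rolling(index_column="time", period="7d", group_by="patient_id")
    .agg(
        pl.col("index").min().alias("min_index"),
        pl.col("index").max().alias("max_index"),
    )
    .collect()
)
print(windows)  # one row per event: the row-index range inside its trailing 7-day window
```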
14 changes: 6 additions & 8 deletions src/MEDS_tabular_automl/generate_ts_features.py
@@ -57,7 +57,7 @@ def get_long_code_df(
.to_series()
.to_numpy()
)
assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type"
assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type"
data = np.ones(df.select(pl.len()).collect().item(), dtype=np.bool_)
return data, (rows, cols)

@@ -76,9 +76,7 @@ def get_long_value_df(
the CSR sparse matrix.
"""
column_to_int = {feature_name_to_code(col): i for i, col in enumerate(ts_columns)}
value_df = (
df.with_row_index("index").drop_nulls("numerical_value").filter(pl.col("code").is_in(ts_columns))
)
value_df = df.with_row_index("index").drop_nulls("numeric_value").filter(pl.col("code").is_in(ts_columns))
rows = value_df.select(pl.col("index")).collect().to_series().to_numpy()
cols = (
value_df.with_columns(pl.col("code").cast(str).replace(column_to_int).cast(int).alias("value_index"))
@@ -87,8 +85,8 @@
.to_series()
.to_numpy()
)
assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type"
data = value_df.select(pl.col("numerical_value")).collect().to_series().to_numpy()
assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type"
data = value_df.select(pl.col("numeric_value")).collect().to_series().to_numpy()
return data, (rows, cols)


@@ -109,15 +107,15 @@ def summarize_dynamic_measurements(
of aggregated values.
"""
logger.info("Generating Sparse matrix for Time Series Features")
id_cols = ["patient_id", "timestamp"]
id_cols = ["patient_id", "time"]

# Confirm dataframe is sorted
check_df = df.select(pl.col(id_cols))
assert check_df.sort(by=id_cols).collect().equals(check_df.collect()), "data frame must be sorted"

# Generate sparse matrix
if agg in CODE_AGGREGATIONS:
code_df = df.drop(*(id_cols + ["numerical_value"]))
code_df = df.drop(*(id_cols + ["numeric_value"]))
data, (rows, cols) = get_long_code_df(code_df, ts_columns)
elif agg in VALUE_AGGREGATIONS:
value_df = df.drop(*id_cols)
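As context for `get_long_code_df` / `get_long_value_df` above, the `(data, (rows, cols))` triplets they return are the standard coordinate-style inputs to a SciPy sparse matrix. A minimal sketch with made-up values (not code from the PR):

```python
# Sketch: turning (data, (rows, cols)) triplets into a CSR sparse matrix with SciPy.
import numpy as np
from scipy.sparse import csr_matrix

rows = np.array([0, 1, 3])        # row index = event (row) position in the shard
cols = np.array([0, 2, 1])        # column index = position of the code in ts_columns
data = np.array([1.0, 2.5, 3.0])  # observed values (or booleans for code presence)

matrix = csr_matrix((data, (rows, cols)), shape=(4, 3))
print(matrix.toarray())
```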