DO NOT MERGE. Temporary. #72

Closed · wants to merge 14 commits

3 changes: 2 additions & 1 deletion README.md
@@ -27,7 +27,8 @@ This repository consists of two key pieces:

## Quick Start

To use MEDS-Tab, install the dependencies following commands below:
To use MEDS-Tab, install the dependencies using the commands below. Note that this version of MEDS-Tab is
compatible with [MEDS v0.3](https://github.com/Medical-Event-Data-Standard/meds/releases/tag/0.3.0).

**Pip Install**

6 changes: 5 additions & 1 deletion pyproject.toml
@@ -14,7 +14,11 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = ["polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost", "scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins"]
dependencies = [
"polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost",
"scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins", "meds==0.3",
#"MEDS-transforms==0.0.4",
]

[project.scripts]
meds-tab-describe = "MEDS_tabular_automl.scripts.describe_codes:main"
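
The substantive change in this file is the pin to `meds==0.3`, with `MEDS-transforms` left commented out for now. A minimal sketch, using only the standard library, for confirming at runtime that an environment satisfies that pin; the distribution names are taken from the dependency list above:

```python
from importlib.metadata import PackageNotFoundError, version

# Sketch: report the installed versions of the pinned MEDS dependency and a
# couple of the other declared dependencies.
for dist in ("meds", "polars", "xgboost"):
    try:
        print(f"{dist}=={version(dist)}")
    except PackageNotFoundError:
        print(f"{dist} is not installed")
```
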
4 changes: 3 additions & 1 deletion src/MEDS_tabular_automl/configs/default.yaml
@@ -1,11 +1,13 @@
MEDS_cohort_dir: ???
output_cohort_dir: ???
do_overwrite: False
seed: 1
tqdm: False
worker: 0
loguru_init: False

log_dir: ${output_dir}/.logs/
log_dir: ${output_cohort_dir}/.logs/
cache_dir: ${output_cohort_dir}/.cache

hydra:
verbose: False
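
The change in this config is that logging and caching now hang off the new `output_cohort_dir` key instead of the removed `${output_dir}` reference. A minimal sketch of how those interpolations resolve, using OmegaConf (which ships with the `hydra-core` dependency) and a placeholder cohort directory:

```python
from omegaconf import OmegaConf

# Sketch: resolve the interpolations used in default.yaml with a placeholder path.
cfg = OmegaConf.create(
    {
        "output_cohort_dir": "/data/cohort_output",  # placeholder; supplied by the user at runtime
        "log_dir": "${output_cohort_dir}/.logs/",
        "cache_dir": "${output_cohort_dir}/.cache",
    }
)
print(OmegaConf.to_container(cfg, resolve=True))
# {'output_cohort_dir': '/data/cohort_output',
#  'log_dir': '/data/cohort_output/.logs/',
#  'cache_dir': '/data/cohort_output/.cache'}
```
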
9 changes: 2 additions & 7 deletions src/MEDS_tabular_automl/configs/describe_codes.yaml
@@ -2,13 +2,8 @@ defaults:
- default
- _self_

# split we wish to get metadata for
split: train
# Raw data, must have a subdirectory "train" with the training data split
input_dir: ${MEDS_cohort_dir}/final_cohort/${split}
input_dir: ${output_cohort_dir}/data
# Where to store output code frequency data
cache_dir: ${MEDS_cohort_dir}/.cache
output_dir: ${MEDS_cohort_dir}
output_filepath: ${output_dir}/code_metadata.parquet
output_filepath: ${output_cohort_dir}/metadata/codes.parquet

name: describe_codes
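
With these paths, `describe_codes` now reads shards from `${output_cohort_dir}/data` and writes its code-frequency output to the MEDS-standard `metadata/codes.parquet` location. A small sketch of loading that output with Polars; the placeholder directory and the assumption that the file can be read directly as a single parquet are illustrative only:

```python
import polars as pl

# Sketch: load the code-frequency metadata written by the describe_codes step.
output_cohort_dir = "/data/cohort_output"  # placeholder; set to your cohort output
codes = pl.read_parquet(f"{output_cohort_dir}/metadata/codes.parquet")
print(codes.head())
```
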
17 changes: 6 additions & 11 deletions src/MEDS_tabular_automl/configs/launch_xgboost.yaml
@@ -9,13 +9,12 @@ defaults:
task_name: task

# Task cached data dir
input_dir: ${MEDS_cohort_dir}/${task_name}/task_cache
input_dir: ${output_cohort_dir}/${task_name}/task_cache
# Directory with task labels
input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort
input_label_dir: ${output_cohort_dir}/${task_name}/labels/
# Where to output the model and cached data
output_dir: ${MEDS_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
output_filepath: ${output_dir}/model_metadata.parquet
cache_dir: ${MEDS_cohort_dir}/.cache
model_dir: ${output_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
output_filepath: ${model_dir}/model_metadata.json

# Model parameters
model_params:
@@ -31,13 +30,9 @@ model_params:
keep_data_in_memory: True
binarize_task: True

hydra:
verbose: False
sweep:
dir: ${output_dir}/.logs/
run:
dir: ${output_dir}/.logs/
log_dir: ${model_dir}/.logs/

hydra:
# Optuna Sweeper
sweeper:
sampler:
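
The launcher now writes everything under a timestamped `model_dir` inside `${output_cohort_dir}/model/`, with metadata in `model_metadata.json` and Hydra logs under `${model_dir}/.logs/`. A minimal sketch of picking up the most recent run after training, assuming the path layout above and a placeholder cohort directory; the metadata schema itself is not shown in this diff:

```python
import json
from pathlib import Path

# Sketch: find the newest timestamped model directory produced by the launcher
# and read its metadata file. The timestamp format sorts lexicographically, so
# a plain sort gives chronological order.
output_cohort_dir = Path("/data/cohort_output")  # placeholder
model_dirs = sorted((output_cohort_dir / "model").glob("model_*"))
if model_dirs:
    latest = model_dirs[-1]
    metadata = json.loads((latest / "model_metadata.json").read_text())
    print(latest.name, metadata)
else:
    print("no model runs found")
```
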
6 changes: 3 additions & 3 deletions src/MEDS_tabular_automl/configs/tabularization.yaml
@@ -5,8 +5,8 @@ defaults:

# Raw data
# Where the code metadata is stored
input_code_metadata_fp: ${MEDS_cohort_dir}/code_metadata.parquet
input_dir: ${MEDS_cohort_dir}/final_cohort
output_dir: ${MEDS_cohort_dir}/tabularize
input_code_metadata_fp: ${output_cohort_dir}/metadata/codes.parquet
input_dir: ${output_cohort_dir}/data
output_dir: ${output_cohort_dir}/tabularize

name: tabularization
@@ -1,7 +1,7 @@
# User inputs
allowed_codes: null
min_code_inclusion_frequency: 10
filtered_code_metadata_fp: ${MEDS_cohort_dir}/tabularized_code_metadata.parquet
filtered_code_metadata_fp: ${output_cohort_dir}/tabularized_code_metadata.parquet
window_sizes:
- "1d"
- "7d"
9 changes: 6 additions & 3 deletions src/MEDS_tabular_automl/configs/task_specific_caching.yaml
@@ -5,10 +5,13 @@ defaults:
task_name: task

# Tabularized Data
input_dir: ${MEDS_cohort_dir}/tabularize
input_dir: ${output_cohort_dir}/tabularize
# Where the labels are stored, with columns patient_id, timestamp, label
input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort
input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels
# Where to output the task specific tabularized data
output_dir: ${MEDS_cohort_dir}/${task_name}/task_cache
output_dir: ${output_cohort_dir}/${task_name}/task_cache
output_label_dir: ${output_cohort_dir}/${task_name}/labels

label_column: "boolean_value"

name: task_specific_caching
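
This step now expects task labels under `${MEDS_cohort_dir}/${task_name}/labels` and binarizes on `label_column: "boolean_value"`. A hedged sketch of a minimal label frame with the columns this config appears to expect; only `boolean_value` is named above, and the other column names follow the MEDS v0.3 label convention as an assumption, not something this diff pins down:

```python
from datetime import datetime

import polars as pl

# Sketch: a tiny label frame for one task shard. "patient_id" and
# "prediction_time" are assumed from the MEDS v0.3 label schema and may need
# adjusting; "boolean_value" matches label_column in the config above.
labels = pl.DataFrame(
    {
        "patient_id": [1, 2, 3],
        "prediction_time": [datetime(2021, 1, 2), datetime(2021, 1, 5), datetime(2021, 2, 1)],
        "boolean_value": [True, False, True],
    }
)
print(labels)
# In practice this would be written under ${MEDS_cohort_dir}/${task_name}/labels/.
```
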
52 changes: 26 additions & 26 deletions src/MEDS_tabular_automl/describe_codes.py
@@ -82,7 +82,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
>>> data = pl.DataFrame({
... 'patient_id': [1, 1, 2, 2, 3, 3, 3],
... 'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
... 'timestamp': [
... 'time': [
... None,
... datetime(2021, 1, 1),
... None,
Expand All @@ -91,7 +91,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
... datetime(2021, 1, 4),
... None
... ],
... 'numerical_value': [1, None, 2, 2, None, None, 3]
... 'numeric_value': [1, None, 2, 2, None, None, 3]
... }).lazy()
>>> assert (
... convert_to_freq_dict(compute_feature_frequencies(data).lazy()) == {
@@ -101,29 +101,29 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
... )
"""
static_df = shard_df.filter(
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null()
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_null()
)
static_code_freqs_df = static_df.group_by("code").agg(pl.count("code").alias("count")).collect()
static_code_freqs = {
row["code"] + "/static/present": row["count"] for row in static_code_freqs_df.iter_rows(named=True)
}

static_value_df = static_df.filter(pl.col("numerical_value").is_not_null())
static_value_df = static_df.filter(pl.col("numeric_value").is_not_null())
static_value_freqs_df = (
static_value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect()
static_value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect()
)
static_value_freqs = {
row["code"] + "/static/first": row["count"] for row in static_value_freqs_df.iter_rows(named=True)
}

ts_df = shard_df.filter(
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_not_null()
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_not_null()
)
code_freqs_df = ts_df.group_by("code").agg(pl.count("code").alias("count")).collect()
code_freqs = {row["code"] + "/code": row["count"] for row in code_freqs_df.iter_rows(named=True)}

value_df = ts_df.filter(pl.col("numerical_value").is_not_null())
value_freqs_df = value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect()
value_df = ts_df.filter(pl.col("numeric_value").is_not_null())
value_freqs_df = value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect()
value_freqs = {row["code"] + "/value": row["count"] for row in value_freqs_df.iter_rows(named=True)}

combined_freqs = {**static_code_freqs, **static_value_freqs, **code_freqs, **value_freqs}
@@ -222,23 +222,23 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
>>> fp = NamedTemporaryFile()
>>> pl.DataFrame({
... "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numerical_value": [1, None, 2, 2, None, 5, None, 3]
... "time": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numeric_value": [1, None, 2, 2, None, 5, None, 3]
... }).write_parquet(fp.name)
>>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
shape: (6, 3)
┌──────┬────────────┬─────────────────┐
│ code ┆ timestamp  ┆ numerical_value │
│ ---  ┆ ---        ┆ ---             │
│ str  ┆ str        ┆ i64             │
╞══════╪════════════╪═════════════════╡
│ A    ┆ 2021-01-01 ┆ null            │
│ A    ┆ 2021-01-01 ┆ null            │
│ D    ┆ null       ┆ null            │
│ D    ┆ null       ┆ null            │
│ E    ┆ 2021-01-03 ┆ null            │
│ E    ┆ 2021-01-04 ┆ 3               │
└──────┴────────────┴─────────────────┘
┌──────┬────────────┬───────────────┐
│ code ┆ time       ┆ numeric_value │
│ ---  ┆ ---        ┆ ---           │
│ str  ┆ str        ┆ i64           │
╞══════╪════════════╪═══════════════╡
│ A    ┆ 2021-01-01 ┆ null          │
│ A    ┆ 2021-01-01 ┆ null          │
│ D    ┆ null       ┆ null          │
│ D    ┆ null       ┆ null          │
│ E    ┆ 2021-01-03 ┆ null          │
│ E    ┆ 2021-01-04 ┆ 3             │
└──────┴────────────┴───────────────┘
>>> fp.close()
"""
df = pl.scan_parquet(fp)
Expand All @@ -257,8 +257,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
clear_code_aggregation_suffix(each) for each in get_feature_names("value/sum", allowed_codes)
]

is_static_code = pl.col("timestamp").is_null()
is_numeric_code = pl.col("numerical_value").is_not_null()
is_static_code = pl.col("time").is_null()
is_numeric_code = pl.col("numeric_value").is_not_null()
rare_static_code = is_static_code & ~pl.col("code").is_in(static_present_feature_columns)
rare_ts_code = ~is_static_code & ~pl.col("code").is_in(code_feature_columns)
rare_ts_value = ~is_static_code & ~pl.col("code").is_in(value_feature_columns) & is_numeric_code
@@ -268,8 +268,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
df = df.with_columns(
pl.when(rare_static_value | rare_ts_value)
.then(None)
.otherwise(pl.col("numerical_value"))
.alias("numerical_value")
.otherwise(pl.col("numeric_value"))
.alias("numeric_value")
)
# Drop rows with rare codes
df = df.filter(~(rare_static_code | rare_ts_code))
2 changes: 1 addition & 1 deletion src/MEDS_tabular_automl/generate_static_features.py
@@ -119,7 +119,7 @@ def summarize_static_measurements(
code_subset = df.filter(pl.col("code").is_in(static_first_codes))
first_code_subset = code_subset.group_by(pl.col("patient_id")).first().collect()
static_value_pivot_df = first_code_subset.pivot(
index=["patient_id"], columns=["code"], values=["numerical_value"], aggregate_function=None
index=["patient_id"], columns=["code"], values=["numeric_value"], aggregate_function=None
)
# rename code to feature name
remap_cols = {
10 changes: 5 additions & 5 deletions src/MEDS_tabular_automl/generate_summarized_reps.py
@@ -59,7 +59,7 @@ def get_rolling_window_indicies(index_df: pl.LazyFrame, window_size: str) -> pl.
timedelta = pd.Timedelta(window_size)
return (
index_df.with_row_index("index")
.rolling(index_column="timestamp", period=timedelta, group_by="patient_id")
.rolling(index_column="time", period=timedelta, group_by="patient_id")
.agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
.select(pl.col("min_index", "max_index"))
.collect()
@@ -133,11 +133,11 @@ def compute_agg(
"""Applies aggregation to a sparse matrix using rolling window indices derived from a DataFrame.

Dataframe is expected to only have the relevant columns for aggregating. It should have the patient_id and
timestamp columns, and then only code columns if agg is a code aggregation or only value columns if it is
time columns, and then only code columns if agg is a code aggregation or only value columns if it is
a value aggregation.

Args:
index_df: The DataFrame with 'patient_id' and 'timestamp' columns used for grouping.
index_df: The DataFrame with 'patient_id' and 'time' columns used for grouping.
matrix: The sparse matrix to be aggregated.
window_size: The string defining the rolling window size.
agg: The string specifying the aggregation method.
@@ -149,11 +149,11 @@
"""
group_df = (
index_df.with_row_index("index")
.group_by(["patient_id", "timestamp"], maintain_order=True)
.group_by(["patient_id", "time"], maintain_order=True)
.agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
.collect()
)
index_df = group_df.lazy().select(pl.col("patient_id", "timestamp"))
index_df = group_df.lazy().select(pl.col("patient_id", "time"))
windows = group_df.select(pl.col("min_index", "max_index"))
logger.info("Step 1.5: Running sparse aggregation.")
matrix = aggregate_matrix(windows, matrix, agg, num_features, use_tqdm)
14 changes: 6 additions & 8 deletions src/MEDS_tabular_automl/generate_ts_features.py
@@ -57,7 +57,7 @@ def get_long_code_df(
.to_series()
.to_numpy()
)
assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type"
assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type"
data = np.ones(df.select(pl.len()).collect().item(), dtype=np.bool_)
return data, (rows, cols)

@@ -76,9 +76,7 @@ def get_long_value_df(
the CSR sparse matrix.
"""
column_to_int = {feature_name_to_code(col): i for i, col in enumerate(ts_columns)}
value_df = (
df.with_row_index("index").drop_nulls("numerical_value").filter(pl.col("code").is_in(ts_columns))
)
value_df = df.with_row_index("index").drop_nulls("numeric_value").filter(pl.col("code").is_in(ts_columns))
rows = value_df.select(pl.col("index")).collect().to_series().to_numpy()
cols = (
value_df.with_columns(pl.col("code").cast(str).replace(column_to_int).cast(int).alias("value_index"))
@@ -87,8 +85,8 @@
.to_series()
.to_numpy()
)
assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type"
data = value_df.select(pl.col("numerical_value")).collect().to_series().to_numpy()
assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type"
data = value_df.select(pl.col("numeric_value")).collect().to_series().to_numpy()
return data, (rows, cols)
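
Both `get_long_code_df` and `get_long_value_df` return a `(data, (rows, cols))` triplet rather than a matrix. A minimal usage sketch showing how that triplet becomes the CSR matrix the docstrings refer to (`scipy` is already a declared dependency); `df` and `ts_columns` are assumed to be the same objects passed to the functions above:

```python
import polars as pl
from scipy.sparse import csr_matrix

# Sketch: assemble the sparse value matrix from the triplet returned above.
# One matrix row per MEDS event, one column per time-series feature.
data, (rows, cols) = get_long_value_df(df, ts_columns)
n_events = df.select(pl.len()).collect().item()
value_matrix = csr_matrix((data, (rows, cols)), shape=(n_events, len(ts_columns)))
```
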


@@ -109,15 +107,15 @@
of aggregated values.
"""
logger.info("Generating Sparse matrix for Time Series Features")
id_cols = ["patient_id", "timestamp"]
id_cols = ["patient_id", "time"]

# Confirm dataframe is sorted
check_df = df.select(pl.col(id_cols))
assert check_df.sort(by=id_cols).collect().equals(check_df.collect()), "data frame must be sorted"

# Generate sparse matrix
if agg in CODE_AGGREGATIONS:
code_df = df.drop(*(id_cols + ["numerical_value"]))
code_df = df.drop(*(id_cols + ["numeric_value"]))
data, (rows, cols) = get_long_code_df(code_df, ts_columns)
elif agg in VALUE_AGGREGATIONS:
value_df = df.drop(*id_cols)