Commit cf2a4e8
Made a bunch of changes mostly for #66 but tests are currently failing.
mmcdermott committed Aug 11, 2024
1 parent 3314937 commit cf2a4e8
Showing 6 changed files with 53 additions and 25 deletions.
src/MEDS_tabular_automl/configs/launch_xgboost.yaml (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ task_name: task
 # Task cached data dir
 input_dir: ${output_cohort_dir}/${task_name}/task_cache
 # Directory with task labels
-input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/
+input_label_dir: ${output_cohort_dir}/${task_name}/labels/
 # Where to output the model and cached data
 model_dir: ${output_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
 output_filepath: ${model_dir}/model_metadata.json
src/MEDS_tabular_automl/configs/tabularization.yaml (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,7 @@ defaults:
 # Raw data
 # Where the code metadata is stored
 input_code_metadata_fp: ${output_cohort_dir}/metadata/codes.parquet
-input_dir: ${output_cohort_data}/data
+input_dir: ${output_cohort_dir}/data
 output_dir: ${output_cohort_dir}/tabularize
 
 name: tabularization
src/MEDS_tabular_automl/configs/task_specific_caching.yaml (3 changes: 3 additions & 0 deletions)
@@ -10,5 +10,8 @@ input_dir: ${output_cohort_dir}/tabularize
 input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels
 # Where to output the task specific tabularized data
 output_dir: ${output_cohort_dir}/${task_name}/task_cache
+output_label_dir: ${output_cohort_dir}/${task_name}/labels
+
+label_column: "boolean_value"
 
 name: task_specific_caching
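
A side note on how these keys resolve: the ${...} references are OmegaConf interpolations, so the new output_label_dir is derived from output_cohort_dir and task_name when the Hydra config is composed. A minimal sketch with made-up directory values; the real ones come from the launch overrides:

from omegaconf import OmegaConf

# Made-up values standing in for the Hydra launch overrides.
cfg = OmegaConf.create(
    {
        "output_cohort_dir": "/data/cohort_output",
        "task_name": "task",
        "output_dir": "${output_cohort_dir}/${task_name}/task_cache",
        "output_label_dir": "${output_cohort_dir}/${task_name}/labels",
        "label_column": "boolean_value",
    }
)

# Interpolations resolve on access.
print(cfg.output_label_dir)  # /data/cohort_output/task/labels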
src/MEDS_tabular_automl/scripts/cache_task.py (51 changes: 36 additions & 15 deletions)
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 
 """Aggregates time-series data for feature columns across different window sizes."""
+from functools import partial
 from importlib.resources import files
 from pathlib import Path
 
@@ -10,14 +11,17 @@
 import scipy.sparse as sp
 from omegaconf import DictConfig
 
+from ..describe_codes import filter_parquet, get_feature_columns
 from ..file_name import list_subdir_files
 from ..mapper import wrap as rwlock_wrap
 from ..utils import (
     CODE_AGGREGATIONS,
     STATIC_CODE_AGGREGATION,
     STATIC_VALUE_AGGREGATION,
     VALUE_AGGREGATIONS,
+    get_events_df,
     get_shard_prefix,
+    get_unique_time_events_df,
     hydra_loguru_init,
     load_matrix,
     load_tqdm,
@@ -37,6 +41,10 @@
 ]
 
 
+def write_lazyframe(df: pl.LazyFrame, fp: Path):
+    df.collect().write_parquet(fp, use_pyarrow=True)
+
+
 def generate_row_cached_matrix(matrix: sp.coo_array, label_df: pl.LazyFrame) -> sp.coo_array:
     """Generates row-cached matrix for a given matrix and label DataFrame.
@@ -80,31 +88,44 @@ def main(cfg: DictConfig):
     tabularization_tasks = list_subdir_files(cfg.input_dir, "npz")
     np.random.shuffle(tabularization_tasks)
 
+    label_dir = Path(cfg.input_label_dir)
+    label_df = pl.scan_parquet(label_dir / "**/*.parquet").rename({"prediction_time": "time"})
+
+    feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp)
+
     # iterate through them
     for data_fp in iter_wrapper(tabularization_tasks):
         # parse as time series agg
         split, shard_num, window_size, code_type, agg_name = Path(data_fp).with_suffix("").parts[-5:]
-        label_fp = Path(cfg.input_label_dir) / split / f"{shard_num}.parquet"
-        out_fp = (Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, data_fp)).with_suffix(".npz")
-        assert label_fp.exists(), f"Output file {label_fp} does not exist."
 
-        def read_fn(fps):
-            matrix_fp, label_fp = fps
-            return load_matrix(matrix_fp), pl.scan_parquet(label_fp)
+        raw_data_fp = Path(cfg.output_cohort_dir) / "data" / split / f"{shard_num}.parquet"
+        raw_data_df = filter_parquet(raw_data_fp, cfg.tabularization._resolved_codes)
+        raw_data_df = (
+            get_unique_time_events_df(get_events_df(raw_data_df, feature_columns))
+            .with_row_index("event_id")
+            .select("patient_id", "time", "event_id")
+        )
+        shard_label_df = label_df.join_asof(other=raw_data_df, by="patient_id", on="time")
 
-        def compute_fn(shard_dfs):
-            matrix, label_df = shard_dfs
-            cache_matrix = generate_row_cached_matrix(matrix, label_df)
-            return cache_matrix
+        shard_label_fp = Path(cfg.output_label_dir) / split / f"{shard_num}.parquet"
+        rwlock_wrap(
+            raw_data_fp,
+            shard_label_fp,
+            pl.scan_parquet,
+            write_lazyframe,
+            lambda df: shard_label_df,
+            do_overwrite=cfg.do_overwrite,
+            do_return=False,
+        )
 
-        def write_fn(cache_matrix, out_fp):
-            write_df(cache_matrix, out_fp, do_overwrite=cfg.do_overwrite)
+        out_fp = (Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, data_fp)).with_suffix(".npz")
+        compute_fn = partial(generate_row_cached_matrix, label_df=shard_label_df)
+        write_fn = partial(write_df, do_overwrite=cfg.do_overwrite)
 
-        in_fps = [data_fp, label_fp]
         rwlock_wrap(
-            in_fps,
+            data_fp,
             out_fp,
-            read_fn,
+            load_matrix,
             write_fn,
             compute_fn,
             do_overwrite=cfg.do_overwrite,
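
For orientation: the new labeling logic above hinges on polars' as-of join, which matches each label's prediction_time (renamed to "time") to the most recent event at or before it within each patient_id. A minimal sketch of that semantics with toy data; the frames and values here are made up, not the repo's fixtures or helpers:

from datetime import datetime

import polars as pl

# Toy per-patient events, sorted by time as join_asof requires.
events = pl.DataFrame(
    {
        "patient_id": [1, 1, 1],
        "time": [datetime(2020, 1, 1), datetime(2020, 1, 5), datetime(2020, 1, 9)],
        "event_id": [0, 1, 2],
    }
).lazy()

# Toy labels keyed by prediction_time, renamed to "time" as in the commit.
labels = pl.DataFrame(
    {
        "patient_id": [1],
        "time": [datetime(2020, 1, 6)],
        "boolean_value": [True],
    }
).lazy()

# Backward as-of join (the default strategy): each label row picks up the
# last event whose time is <= the label's time, matched within patient_id.
print(labels.join_asof(events, on="time", by="patient_id").collect())

With the default "backward" strategy, the label dated 2020-01-06 is assigned event_id == 1, the 2020-01-05 event.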
tests/test_integration.py (10 changes: 6 additions & 4 deletions)
@@ -74,11 +74,13 @@ def test_integration():
     for split, data in MEDS_OUTPUTS.items():
         file_path = output_cohort_dir / "data" / f"{split}.parquet"
         file_path.parent.mkdir(exist_ok=True)
-        df = pl.read_csv(StringIO(data))
-        df.with_columns(pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet(file_path)
+        df = pl.read_csv(StringIO(data)).with_columns(
+            pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")
+        )
+        df.write_parquet(file_path)
         all_data.append(df)
 
-    all_data = pl.concat(all_data, how="diagnoal_relaxed")
+    all_data = pl.concat(all_data, how="diagonal_relaxed").sort(by=["patient_id", "time"])
 
     # Check the files are not empty
     meds_files = list_subdir_files(Path(cfg.input_dir), "parquet")
@@ -210,7 +212,7 @@ def test_integration():
     overrides = [f"{k}={v}" for k, v in cache_config.items()]
     cfg = compose(config_name="task_specific_caching", overrides=overrides)  # config.yaml
 
-    df = get_unique_time_events_df(get_events_df(all_data, feature_columns)).collect()
+    df = get_unique_time_events_df(get_events_df(all_data.lazy(), feature_columns)).collect()
     pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]])
     df = df.with_columns(pl.Series(name="boolean_value", values=pseudo_labels))
     df = df.select("patient_id", pl.col("time").alias("prediction_time"), "boolean_value")
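
The how="diagnoal_relaxed" typo fixed in the first hunk above would fail at runtime, since polars validates the concat strategy name. A toy illustration of what "diagonal_relaxed" actually does; these frames are made up, not the test fixtures:

import polars as pl

a = pl.DataFrame({"patient_id": [1], "time": ["2020-01-01"], "numeric_value": [1.0]})
b = pl.DataFrame({"patient_id": [2], "time": ["2020-01-02"], "code": ["LAB//X"]})

# "diagonal_relaxed" unions the schemas: columns absent from a frame are
# filled with nulls, and mismatched dtypes are relaxed to a common supertype.
print(pl.concat([a, b], how="diagonal_relaxed").sort(by=["patient_id", "time"]))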
tests/test_tabularize.py (10 changes: 6 additions & 4 deletions)
@@ -177,11 +177,13 @@ def test_tabularize():
     for split, data in MEDS_OUTPUTS.items():
         file_path = output_cohort_dir / "data" / f"{split}.parquet"
         file_path.parent.mkdir(exist_ok=True)
-        df = pl.read_csv(StringIO(data))
-        df.with_columns(pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet(file_path)
+        df = pl.read_csv(StringIO(data)).with_columns(
+            pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")
+        )
+        df.write_parquet(file_path)
         all_data.append(df)
 
-    all_data = pl.concat(all_data, how="diagnoal_relaxed")
+    all_data = pl.concat(all_data, how="diagonal_relaxed").sort(by=["patient_id", "time"])
 
     # Check the files are not empty
     meds_files = list_subdir_files(Path(cfg.input_dir), "parquet")
@@ -297,7 +299,7 @@ def test_tabularize():
     cfg = compose(config_name="task_specific_caching", overrides=overrides)  # config.yaml
 
     # Create fake labels
-    df = get_unique_time_events_df(get_events_df(all_data, feature_columns)).collect()
+    df = get_unique_time_events_df(get_events_df(all_data.lazy(), feature_columns)).collect()
     pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]])
     df = df.with_columns(pl.Series(name="boolean_value", values=pseudo_labels))
     df = df.select("patient_id", pl.col("time").alias("prediction_time"), "boolean_value")