DO NOT MERGE. Temporary. #72

Closed · wants to merge 14 commits · Changes from 8 commits
src/MEDS_tabular_automl/configs/default.yaml — 4 changes: 3 additions & 1 deletion

@@ -1,11 +1,13 @@
 MEDS_cohort_dir: ???
+output_cohort_dir: ???
 do_overwrite: False
 seed: 1
 tqdm: False
 worker: 0
 loguru_init: False

-log_dir: ${output_dir}/.logs/
+log_dir: ${output_cohort_dir}/.logs/
+cache_dir: ${output_cohort_dir}/.cache

 hydra:
   verbose: False
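
Reviewer note: the effect of this change is that every derived path (logs, cache) now hangs off the new `output_cohort_dir` key instead of the input cohort directory. A minimal sketch of how the interpolations resolve, using plain OmegaConf rather than full Hydra composition (the directory values are invented for illustration):

```python
from omegaconf import OmegaConf

# Invented roots; in practice both come from required CLI overrides (the ??? values).
cfg = OmegaConf.create(
    {
        "MEDS_cohort_dir": "/data/meds",
        "output_cohort_dir": "/data/meds_output",
        "log_dir": "${output_cohort_dir}/.logs/",
        "cache_dir": "${output_cohort_dir}/.cache",
    }
)

# Interpolations resolve on access, so outputs no longer touch the input directory.
assert cfg.log_dir == "/data/meds_output/.logs/"
assert cfg.cache_dir == "/data/meds_output/.cache"
```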
src/MEDS_tabular_automl/configs/describe_codes.yaml — 9 changes: 2 additions & 7 deletions

@@ -2,13 +2,8 @@ defaults:
   - default
   - _self_

-# split we wish to get metadata for
-split: train
-# Raw data, must have a subdirectory "train" with the training data split
-input_dir: ${MEDS_cohort_dir}/final_cohort/${split}
+input_dir: ${MEDS_cohort_dir}/data
 # Where to store output code frequency data
-cache_dir: ${MEDS_cohort_dir}/.cache
-output_dir: ${MEDS_cohort_dir}
-output_filepath: ${output_dir}/code_metadata.parquet
+output_filepath: ${output_cohort_dir}/metadata/codes.parquet

 name: describe_codes
src/MEDS_tabular_automl/configs/launch_xgboost.yaml — 17 changes: 6 additions & 11 deletions

@@ -9,13 +9,12 @@
 task_name: task

 # Task cached data dir
-input_dir: ${MEDS_cohort_dir}/${task_name}/task_cache
+input_dir: ${output_cohort_dir}/${task_name}/task_cache
 # Directory with task labels
-input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort
+input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/
 # Where to output the model and cached data
-output_dir: ${MEDS_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
-output_filepath: ${output_dir}/model_metadata.parquet
-cache_dir: ${MEDS_cohort_dir}/.cache
+model_dir: ${output_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
+output_filepath: ${model_dir}/model_metadata.json

 # Model parameters
 model_params:
@@ -31,13 +30,9 @@ model_params:
   keep_data_in_memory: True
   binarize_task: True

-hydra:
-  verbose: False
-  sweep:
-    dir: ${output_dir}/.logs/
-  run:
-    dir: ${output_dir}/.logs/
+log_dir: ${model_dir}/.logs/

+hydra:
   # Optuna Sweeper
   sweeper:
     sampler:
src/MEDS_tabular_automl/configs/tabularization.yaml — 6 changes: 3 additions & 3 deletions

@@ -5,8 +5,8 @@ defaults:

 # Raw data
 # Where the code metadata is stored
-input_code_metadata_fp: ${MEDS_cohort_dir}/code_metadata.parquet
-input_dir: ${MEDS_cohort_dir}/final_cohort
-output_dir: ${MEDS_cohort_dir}/tabularize
+input_code_metadata_fp: ${output_cohort_dir}/metadata/codes.parquet
+input_dir: ${MEDS_cohort_dir}/data
+output_dir: ${output_cohort_dir}/tabularize

 name: tabularization
@@ -1,7 +1,7 @@
 # User inputs
 allowed_codes: null
 min_code_inclusion_frequency: 10
-filtered_code_metadata_fp: ${MEDS_cohort_dir}/tabularized_code_metadata.parquet
+filtered_code_metadata_fp: ${output_cohort_dir}/tabularized_code_metadata.parquet
 window_sizes:
   - "1d"
   - "7d"
src/MEDS_tabular_automl/configs/task_specific_caching.yaml — 6 changes: 3 additions & 3 deletions

@@ -5,10 +5,10 @@ defaults:
 task_name: task

 # Tabularized Data
-input_dir: ${MEDS_cohort_dir}/tabularize
+input_dir: ${output_cohort_dir}/tabularize
 # Where the labels are stored, with columns patient_id, timestamp, label
-input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort
+input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels
 # Where to output the task specific tabularized data
-output_dir: ${MEDS_cohort_dir}/${task_name}/task_cache
+output_dir: ${output_cohort_dir}/${task_name}/task_cache

 name: task_specific_caching
src/MEDS_tabular_automl/describe_codes.py — 30 changes: 15 additions & 15 deletions

@@ -82,7 +82,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
     >>> data = pl.DataFrame({
     ...     'patient_id': [1, 1, 2, 2, 3, 3, 3],
     ...     'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
-    ...     'timestamp': [
+    ...     'time': [
     ...         None,
     ...         datetime(2021, 1, 1),
     ...         None,
@@ -91,7 +91,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
     ...         datetime(2021, 1, 4),
     ...         None
     ...     ],
-    ...     'numerical_value': [1, None, 2, 2, None, None, 3]
+    ...     'numeric_value': [1, None, 2, 2, None, None, 3]
     ... }).lazy()
     >>> assert (
     ...     convert_to_freq_dict(compute_feature_frequencies(data).lazy()) == {
@@ -101,29 +101,29 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
     ... )
     """
     static_df = shard_df.filter(
-        pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null()
+        pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_null()
     )
     static_code_freqs_df = static_df.group_by("code").agg(pl.count("code").alias("count")).collect()
     static_code_freqs = {
         row["code"] + "/static/present": row["count"] for row in static_code_freqs_df.iter_rows(named=True)
     }

-    static_value_df = static_df.filter(pl.col("numerical_value").is_not_null())
+    static_value_df = static_df.filter(pl.col("numeric_value").is_not_null())
     static_value_freqs_df = (
-        static_value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect()
+        static_value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect()
     )
     static_value_freqs = {
         row["code"] + "/static/first": row["count"] for row in static_value_freqs_df.iter_rows(named=True)
     }

     ts_df = shard_df.filter(
-        pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_not_null()
+        pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_not_null()
     )
     code_freqs_df = ts_df.group_by("code").agg(pl.count("code").alias("count")).collect()
     code_freqs = {row["code"] + "/code": row["count"] for row in code_freqs_df.iter_rows(named=True)}

-    value_df = ts_df.filter(pl.col("numerical_value").is_not_null())
-    value_freqs_df = value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect()
+    value_df = ts_df.filter(pl.col("numeric_value").is_not_null())
+    value_freqs_df = value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect()
     value_freqs = {row["code"] + "/value": row["count"] for row in value_freqs_df.iter_rows(named=True)}

     combined_freqs = {**static_code_freqs, **static_value_freqs, **code_freqs, **value_freqs}
@@ -222,13 +222,13 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
     >>> fp = NamedTemporaryFile()
     >>> pl.DataFrame({
     ...     "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
-    ...     "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
-    ...     "numerical_value": [1, None, 2, 2, None, 5, None, 3]
+    ...     "time": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
+    ...     "numeric_value": [1, None, 2, 2, None, 5, None, 3]
     ... }).write_parquet(fp.name)
     >>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
     shape: (6, 3)
     ┌──────┬────────────┬─────────────────┐
-    │ code ┆ timestamp  ┆ numerical_value │
+    │ code ┆ time       ┆ numeric_value   │
     │ ---  ┆ ---        ┆ ---             │
     │ str  ┆ str        ┆ i64             │
     ╞══════╪════════════╪═════════════════╡
@@ -257,8 +257,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
         clear_code_aggregation_suffix(each) for each in get_feature_names("value/sum", allowed_codes)
     ]

-    is_static_code = pl.col("timestamp").is_null()
-    is_numeric_code = pl.col("numerical_value").is_not_null()
+    is_static_code = pl.col("time").is_null()
+    is_numeric_code = pl.col("numeric_value").is_not_null()
     rare_static_code = is_static_code & ~pl.col("code").is_in(static_present_feature_columns)
     rare_ts_code = ~is_static_code & ~pl.col("code").is_in(code_feature_columns)
     rare_ts_value = ~is_static_code & ~pl.col("code").is_in(value_feature_columns) & is_numeric_code
@@ -268,8 +268,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
     df = df.with_columns(
         pl.when(rare_static_value | rare_ts_value)
         .then(None)
-        .otherwise(pl.col("numerical_value"))
-        .alias("numerical_value")
+        .otherwise(pl.col("numeric_value"))
+        .alias("numeric_value")
     )
     # Drop rows with rare codes
     df = df.filter(~(rare_static_code | rare_ts_code))
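
Reviewer note: to make the rename concrete, here is a tiny self-contained sketch of the static vs. time-series split above under the new MEDS column names (`time`, `numeric_value`); the data is invented:

```python
from datetime import datetime

import polars as pl

df = pl.LazyFrame(
    {
        "patient_id": [1, 1, 2],
        "code": ["EYE_COLOR//BROWN", "HR", "HR"],
        "time": [None, datetime(2021, 1, 1), datetime(2021, 1, 2)],
        "numeric_value": [None, 88.0, 90.0],
    }
)

# Static rows carry a null time; dynamic (time-series) rows carry a real one.
static_df = df.filter(pl.col("code").is_not_null() & pl.col("time").is_null())
ts_df = df.filter(pl.col("code").is_not_null() & pl.col("time").is_not_null())

assert static_df.collect().height == 1  # the EYE_COLOR row
assert ts_df.collect().height == 2      # the two HR observations
```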
src/MEDS_tabular_automl/generate_static_features.py — 2 changes: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ def summarize_static_measurements(
     code_subset = df.filter(pl.col("code").is_in(static_first_codes))
     first_code_subset = code_subset.group_by(pl.col("patient_id")).first().collect()
     static_value_pivot_df = first_code_subset.pivot(
-        index=["patient_id"], columns=["code"], values=["numerical_value"], aggregate_function=None
+        index=["patient_id"], columns=["code"], values=["numeric_value"], aggregate_function=None
     )
     # rename code to feature name
     remap_cols = {
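
Reviewer note: a sketch of the pivot this hunk touches — one wide column per static code, holding the patient's first `numeric_value`. Data is invented, and the keyword-style `pivot(...)` signature simply mirrors the call in the diff (newer polars releases rename `columns=` to `on=`):

```python
import polars as pl

first_vals = pl.DataFrame(
    {
        "patient_id": [239684, 1195293],
        "code": ["HEIGHT", "HEIGHT"],
        "numeric_value": [175.27, 164.69],
    }
)

# One row per patient, one column per static code.
wide = first_vals.pivot(
    index=["patient_id"], columns=["code"], values=["numeric_value"], aggregate_function=None
)
print(wide)  # columns: patient_id, HEIGHT
```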
src/MEDS_tabular_automl/generate_summarized_reps.py — 10 changes: 5 additions & 5 deletions

@@ -59,7 +59,7 @@ def get_rolling_window_indicies(index_df: pl.LazyFrame, window_size: str) -> pl.
     timedelta = pd.Timedelta(window_size)
     return (
         index_df.with_row_index("index")
-        .rolling(index_column="timestamp", period=timedelta, group_by="patient_id")
+        .rolling(index_column="time", period=timedelta, group_by="patient_id")
         .agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
         .select(pl.col("min_index", "max_index"))
         .collect()
@@ -133,11 +133,11 @@ def compute_agg(
     """Applies aggregation to a sparse matrix using rolling window indices derived from a DataFrame.

     Dataframe is expected to only have the relevant columns for aggregating. It should have the patient_id and
-    timestamp columns, and then only code columns if agg is a code aggregation or only value columns if it is
+    time columns, and then only code columns if agg is a code aggregation or only value columns if it is
     a value aggreagation.

     Args:
-        index_df: The DataFrame with 'patient_id' and 'timestamp' columns used for grouping.
+        index_df: The DataFrame with 'patient_id' and 'time' columns used for grouping.
         matrix: The sparse matrix to be aggregated.
         window_size: The string defining the rolling window size.
         agg: The string specifying the aggregation method.
@@ -149,11 +149,11 @@ def compute_agg(
     """
     group_df = (
         index_df.with_row_index("index")
-        .group_by(["patient_id", "timestamp"], maintain_order=True)
+        .group_by(["patient_id", "time"], maintain_order=True)
        .agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")])
         .collect()
     )
-    index_df = group_df.lazy().select(pl.col("patient_id", "timestamp"))
+    index_df = group_df.lazy().select(pl.col("patient_id", "time"))
     windows = group_df.select(pl.col("min_index", "max_index"))
     logger.info("Step 1.5: Running sparse aggregation.")
     matrix = aggregate_matrix(windows, matrix, agg, num_features, use_tqdm)
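
Reviewer note: the renamed `rolling(index_column="time", ...)` call is the heart of the windowing logic, so here is a toy, runnable version of `get_rolling_window_indicies` with invented events (two within 7 days of each other, one outside):

```python
from datetime import datetime

import polars as pl

index_df = pl.LazyFrame(
    {
        "patient_id": [1, 1, 1, 2],
        "time": [
            datetime(2021, 1, 1),
            datetime(2021, 1, 5),
            datetime(2021, 1, 20),
            datetime(2021, 1, 2),
        ],
    }
)

# For each event, the [min_index, max_index] span of rows inside the trailing window.
windows = (
    index_df.with_row_index("index")
    .rolling(index_column="time", period="7d", group_by="patient_id")
    .agg(pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index"))
    .select("min_index", "max_index")
    .collect()
)
print(windows)  # the Jan 5 event spans rows 0-1; Jan 20 stands alone
```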
src/MEDS_tabular_automl/generate_ts_features.py — 14 changes: 6 additions & 8 deletions

@@ -57,7 +57,7 @@ def get_long_code_df(
         .to_series()
         .to_numpy()
     )
-    assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type"
+    assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type"
     data = np.ones(df.select(pl.len()).collect().item(), dtype=np.bool_)
     return data, (rows, cols)

@@ -76,9 +76,7 @@ def get_long_value_df(
         the CSR sparse matrix.
     """
     column_to_int = {feature_name_to_code(col): i for i, col in enumerate(ts_columns)}
-    value_df = (
-        df.with_row_index("index").drop_nulls("numerical_value").filter(pl.col("code").is_in(ts_columns))
-    )
+    value_df = df.with_row_index("index").drop_nulls("numeric_value").filter(pl.col("code").is_in(ts_columns))
     rows = value_df.select(pl.col("index")).collect().to_series().to_numpy()
     cols = (
         value_df.with_columns(pl.col("code").cast(str).replace(column_to_int).cast(int).alias("value_index"))
@@ -87,8 +85,8 @@ def get_long_value_df(
         .to_series()
         .to_numpy()
     )
-    assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type"
-    data = value_df.select(pl.col("numerical_value")).collect().to_series().to_numpy()
+    assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type"
+    data = value_df.select(pl.col("numeric_value")).collect().to_series().to_numpy()
     return data, (rows, cols)

@@ -109,15 +107,15 @@ def summarize_dynamic_measurements(
         of aggregated values.
     """
     logger.info("Generating Sparse matrix for Time Series Features")
-    id_cols = ["patient_id", "timestamp"]
+    id_cols = ["patient_id", "time"]

     # Confirm dataframe is sorted
     check_df = df.select(pl.col(id_cols))
     assert check_df.sort(by=id_cols).collect().equals(check_df.collect()), "data frame must be sorted"

     # Generate sparse matrix
     if agg in CODE_AGGREGATIONS:
-        code_df = df.drop(*(id_cols + ["numerical_value"]))
+        code_df = df.drop(*(id_cols + ["numeric_value"]))
         data, (rows, cols) = get_long_code_df(code_df, ts_columns)
     elif agg in VALUE_AGGREGATIONS:
         value_df = df.drop(*id_cols)
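
Reviewer note: both `get_long_code_df` and `get_long_value_df` return COO-style triplets; a minimal sketch (invented values) of how those become the sparse matrix downstream:

```python
import numpy as np
from scipy.sparse import csr_array

data = np.array([1.0, 2.5, 0.5])  # numeric_value entries (or booleans for code presence)
rows = np.array([0, 0, 2])        # event row indices
cols = np.array([1, 3, 1])        # feature column indices

# shape = (n_events, n_features); here 3 events x 4 features.
matrix = csr_array((data, (rows, cols)), shape=(3, 4))
print(matrix.toarray())
```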
src/MEDS_tabular_automl/scripts/describe_codes.py — 11 changes: 1 addition & 10 deletions

@@ -8,7 +8,7 @@
 import numpy as np
 import polars as pl
 from loguru import logger
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig

 from ..describe_codes import (
     compute_feature_frequencies,
@@ -36,15 +36,6 @@ def main(cfg: DictConfig):
     if not cfg.loguru_init:
         hydra_loguru_init()

-    # Store Config
-    output_dir = Path(cfg.output_dir)
-    output_dir.mkdir(exist_ok=True, parents=True)
-    OmegaConf.save(cfg, output_dir / "config.yaml")
-
-    # Create output dir
-    input_dir = Path(cfg.input_dir)
-    input_dir.mkdir(exist_ok=True, parents=True)
-
     # 0. Identify Output Columns and Frequencies
     logger.info("Iterating through shards and caching feature frequencies.")
src/MEDS_tabular_automl/scripts/launch_xgboost.py — 9 changes: 4 additions & 5 deletions

@@ -1,5 +1,4 @@
 from collections.abc import Callable, Mapping
-from datetime import datetime
 from importlib.resources import files
 from pathlib import Path

@@ -440,10 +439,10 @@ def main(cfg: DictConfig) -> float:
         # print("Held Out Iterator Time: \n", model.iheld_out._profile_durations())

         # save model
-        save_dir = Path(cfg.output_dir)
-        save_dir.mkdir(parents=True, exist_ok=True)
-        model_time = datetime.now().strftime("%H%M%S%f")
-        model.model.save_model(save_dir / f"{auc:.4f}_model_{model_time}.json")
+        output_fp = Path(cfg.output_filepath)
+        output_fp.parent.mkdir(parents=True, exist_ok=True)
+
+        model.model.save_model(output_fp)
     except Exception as e:
         logger.error(f"Error occurred: {e}")
         auc = 0.0
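
Reviewer note: the save now targets the configured `output_filepath` directly instead of an AUC-stamped name under `output_dir`, which is what lets the `datetime` import go. A hedged sketch of the new flow with a toy model (paths and model are invented; the real code saves `model.model`, an XGBoost Booster):

```python
from pathlib import Path

import xgboost as xgb

# Toy stand-in for the trained model in the script.
model = xgb.XGBClassifier(n_estimators=2)
model.fit([[0.0], [1.0], [0.0], [1.0]], [0, 1, 0, 1])

output_fp = Path("/tmp/model_dir/model_metadata.json")  # stands in for cfg.output_filepath
output_fp.parent.mkdir(parents=True, exist_ok=True)
model.save_model(str(output_fp))  # JSON format inferred from the .json suffix
```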
src/MEDS_tabular_automl/utils.py — 18 changes: 8 additions & 10 deletions

@@ -284,7 +284,7 @@ def write_df(df: pl.LazyFrame | pl.DataFrame | coo_array, fp: Path, do_overwrite


 def get_events_df(shard_df: pl.LazyFrame, feature_columns) -> pl.LazyFrame:
-    """Extracts and filters an Events LazyFrame with one row per observation (timestamps can be duplicated).
+    """Extracts and filters an Events LazyFrame with one row per observation (times can be duplicated).

     Args:
         shard_df: The LazyFrame shard from which to extract events.
@@ -296,28 +296,26 @@ def get_events_df(shard_df: pl.LazyFrame, feature_columns) -> pl.LazyFrame:
     # Filter out feature_columns that were not present in the training set
     raw_feature_columns = ["/".join(c.split("/")[:-1]) for c in feature_columns]
     shard_df = shard_df.filter(pl.col("code").is_in(raw_feature_columns))
-    # Drop rows with missing timestamp or code to get events
-    ts_shard_df = shard_df.drop_nulls(subset=["timestamp", "code"])
+    # Drop rows with missing time or code to get events
+    ts_shard_df = shard_df.drop_nulls(subset=["time", "code"])
     return ts_shard_df


 def get_unique_time_events_df(events_df: pl.LazyFrame) -> pl.LazyFrame:
-    """Ensures all timestamps in the events LazyFrame are unique and sorted by patient_id and timestamp.
+    """Ensures all times in the events LazyFrame are unique and sorted by patient_id and time.

     Args:
         events_df: Events LazyFrame to process.

     Returns:
-        A LazyFrame with unique timestamps, sorted by patient_id and timestamp.
+        A LazyFrame with unique times, sorted by patient_id and time.
     """
-    assert events_df.select(pl.col("timestamp")).null_count().collect().item() == 0
+    assert events_df.select(pl.col("time")).null_count().collect().item() == 0
     # Check events_df is sorted - so it aligns with the ts_matrix we generate later in the pipeline
     events_df = (
-        events_df.drop_nulls("timestamp")
-        .select(pl.col(["patient_id", "timestamp"]))
-        .unique(maintain_order=True)
+        events_df.drop_nulls("time").select(pl.col(["patient_id", "time"])).unique(maintain_order=True)
     )
-    assert events_df.sort(by=["patient_id", "timestamp"]).collect().equals(events_df.collect())
+    assert events_df.sort(by=["patient_id", "time"]).collect().equals(events_df.collect())
     return events_df
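Reviewer note: a quick runnable check (invented data) of what `get_unique_time_events_df` guarantees after the rename — one row per (patient_id, time) pair, order preserved:

```python
from datetime import datetime

import polars as pl

events = pl.LazyFrame(
    {
        "patient_id": [1, 1, 1],
        "time": [datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 2)],
    }
)

unique_events = (
    events.drop_nulls("time")
    .select(pl.col(["patient_id", "time"]))
    .unique(maintain_order=True)
    .collect()
)
assert unique_events.height == 2  # the duplicate (1, 2021-01-01) row collapses
```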
tests/test_integration.py — 44 changes: 15 additions & 29 deletions

@@ -45,34 +45,36 @@ def run_command(script: str, args: list[str], hydra_kwargs: dict[str, str], test
 def test_integration():
     # Step 0: Setup Environment
     with tempfile.TemporaryDirectory() as d:
-        MEDS_cohort_dir = Path(d) / "processed"
+        MEDS_cohort_dir = Path(d) / "MEDS_cohort_dir"
+        output_cohort_dir = Path(d) / "output_cohort_dir"

-        describe_codes_config = {
+        shared_config = {
             "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
+            "output_cohort_dir": str(output_cohort_dir.resolve()),
             "do_overwrite": False,
             "seed": 1,
             "hydra.verbose": True,
             "tqdm": False,
             "loguru_init": True,
         }

+        describe_codes_config = {**shared_config}
+
         with initialize(
             version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
         ):  # path to config.yaml
             overrides = [f"{k}={v}" for k, v in describe_codes_config.items()]
             cfg = compose(config_name="describe_codes", overrides=overrides)  # config.yaml

             # Create the directories
-            (MEDS_cohort_dir / "final_cohort").mkdir(parents=True, exist_ok=True)
+            (MEDS_cohort_dir / "data").mkdir(parents=True, exist_ok=True)

             # Store MEDS outputs
             for split, data in MEDS_OUTPUTS.items():
-                file_path = MEDS_cohort_dir / "final_cohort" / f"{split}.parquet"
+                file_path = MEDS_cohort_dir / "data" / f"{split}.parquet"
                 file_path.parent.mkdir(exist_ok=True)
                 df = pl.read_csv(StringIO(data))
-                df.with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet(
-                    file_path
-                )
+                df.with_columns(pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet(file_path)

             # Check the files are not empty
             meds_files = list_subdir_files(Path(cfg.input_dir), "parquet")
@@ -92,7 +94,6 @@ def test_integration():
             describe_codes_config,
             "describe_codes",
         )
-        assert (Path(cfg.output_dir) / "config.yaml").is_file()
         assert Path(cfg.output_filepath).is_file()

         feature_columns = get_feature_columns(cfg.output_filepath)
@@ -104,12 +105,7 @@ def test_integration():

         # Step 2: Run the static data tabularization script
         tabularize_config = {
-            "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
-            "do_overwrite": False,
-            "seed": 1,
-            "hydra.verbose": True,
-            "tqdm": False,
-            "loguru_init": True,
+            **shared_config,
             "tabularization.min_code_inclusion_frequency": 1,
             "tabularization.window_sizes": "[30d,365d,full]",
         }
@@ -158,12 +154,7 @@ def test_integration():

         # Step 3: Run the time series tabularization script
         tabularize_config = {
-            "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
-            "do_overwrite": False,
-            "seed": 1,
-            "hydra.verbose": True,
-            "tqdm": False,
-            "loguru_init": True,
+            **shared_config,
             "tabularization.min_code_inclusion_frequency": 1,
             "tabularization.window_sizes": "[30d,365d,full]",
         }
@@ -205,12 +196,7 @@ def test_integration():
         )
         # Step 4: Run the task_specific_caching script
         cache_config = {
-            "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
-            "do_overwrite": False,
-            "seed": 1,
-            "hydra.verbose": True,
-            "tqdm": False,
-            "loguru_init": True,
+            **shared_config,
             "tabularization.min_code_inclusion_frequency": 1,
             "tabularization.window_sizes": "[30d,365d,full]",
         }
@@ -220,18 +206,18 @@ def test_integration():
         overrides = [f"{k}={v}" for k, v in cache_config.items()]
         cfg = compose(config_name="task_specific_caching", overrides=overrides)  # config.yaml
         # Create fake labels
-        for f in list_subdir_files(Path(cfg.MEDS_cohort_dir) / "final_cohort", "parquet"):
+        for f in list_subdir_files(Path(cfg.MEDS_cohort_dir) / "data", "parquet"):
             df = pl.scan_parquet(f)
             df = get_unique_time_events_df(get_events_df(df, feature_columns)).collect()
             pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]])
             df = df.with_columns(pl.Series(name="label", values=pseudo_labels))
-            df = df.select(pl.col(["patient_id", "timestamp", "label"]))
+            df = df.select(pl.col(["patient_id", "time", "label"]))
             df = df.with_row_index("event_id")

             split = f.parent.stem
             shard_num = f.stem
             out_f = Path(cfg.input_label_dir) / Path(
-                get_shard_prefix(Path(cfg.MEDS_cohort_dir) / "final_cohort", f)
+                get_shard_prefix(Path(cfg.MEDS_cohort_dir) / "data", f)
             ).with_suffix(".parquet")
             out_f.parent.mkdir(parents=True, exist_ok=True)
             df.write_parquet(out_f)
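
Reviewer note: the repeated per-step config dicts collapse into one `shared_config` that each step spreads and extends; a minimal sketch of the pattern with invented values:

```python
shared_config = {
    "MEDS_cohort_dir": "/tmp/MEDS_cohort_dir",
    "output_cohort_dir": "/tmp/output_cohort_dir",
    "do_overwrite": False,
    "seed": 1,
}

tabularize_config = {
    **shared_config,
    "tabularization.min_code_inclusion_frequency": 1,
    "tabularization.window_sizes": "[30d,365d,full]",
}

# Later keys win on collision, so a step can still override a shared default.
assert tabularize_config["seed"] == 1
```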
tests/test_tabularize.py — 69 changes: 30 additions & 39 deletions

@@ -35,7 +35,7 @@
 SPLITS_JSON = """{"train/0": [239684, 1195293], "train/1": [68729, 814703], "tuning/0": [754281], "held_out/0": [1500733]}"""  # noqa: E501

 MEDS_TRAIN_0 = """
-patient_id,code,timestamp,numerical_value
+patient_id,code,time,numeric_value
 239684,HEIGHT,,175.271115221764
 239684,EYE_COLOR//BROWN,,
 239684,DOB,1980-12-28T00:00:00.000000,
@@ -68,7 +68,7 @@
 1195293,DISCHARGE,2010-06-20T20:50:04.000000,
 """
 MEDS_TRAIN_1 = """
-patient_id,code,timestamp,numerical_value
+patient_id,code,time,numeric_value
 68729,EYE_COLOR//HAZEL,,
 68729,HEIGHT,,160.3953106166676
 68729,DOB,1978-03-09T00:00:00.000000,
@@ -85,7 +85,7 @@
 814703,DISCHARGE,2010-02-05T07:02:30.000000,
 """
 MEDS_HELD_OUT_0 = """
-patient_id,code,timestamp,numerical_value
+patient_id,code,time,numeric_value
 1500733,HEIGHT,,158.60131573580904
 1500733,EYE_COLOR//BROWN,,
 1500733,DOB,1986-07-20T00:00:00.000000,
@@ -99,7 +99,7 @@
 1500733,DISCHARGE,2010-06-03T16:44:26.000000,
 """
 MEDS_TUNING_0 = """
-patient_id,code,timestamp,numerical_value
+patient_id,code,time,numeric_value
 754281,EYE_COLOR//BROWN,,
 754281,HEIGHT,,166.22261567137025
 754281,DOB,1988-12-19T00:00:00.000000,
@@ -148,34 +148,36 @@

 def test_tabularize():
     with tempfile.TemporaryDirectory() as d:
-        MEDS_cohort_dir = Path(d) / "processed"
+        MEDS_cohort_dir = Path(d) / "MEDS_cohort_dir"
+        output_cohort_dir = Path(d) / "output_cohort_dir"

-        describe_codes_config = {
+        shared_config = {
             "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
+            "output_cohort_dir": str(output_cohort_dir.resolve()),
             "do_overwrite": False,
             "seed": 1,
             "hydra.verbose": True,
             "tqdm": False,
             "loguru_init": True,
         }

+        describe_codes_config = {**shared_config}
+
         with initialize(
             version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
         ):  # path to config.yaml
             overrides = [f"{k}={v}" for k, v in describe_codes_config.items()]
             cfg = compose(config_name="describe_codes", overrides=overrides)  # config.yaml

             # Create the directories
-            (MEDS_cohort_dir / "final_cohort").mkdir(parents=True, exist_ok=True)
+            (MEDS_cohort_dir / "data").mkdir(parents=True, exist_ok=True)

             # Store MEDS outputs
             for split, data in MEDS_OUTPUTS.items():
-                file_path = MEDS_cohort_dir / "final_cohort" / f"{split}.parquet"
+                file_path = MEDS_cohort_dir / "data" / f"{split}.parquet"
                 file_path.parent.mkdir(exist_ok=True)
                 df = pl.read_csv(StringIO(data))
-                df.with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet(
-                    file_path
-                )
+                df.with_columns(pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet(file_path)

             # Check the files are not empty
             meds_files = list_subdir_files(Path(cfg.input_dir), "parquet")
@@ -190,7 +192,6 @@ def test_tabularize():
             # Step 1: Describe Codes - compute code frequencies
             describe_codes.main(cfg)

-            assert (Path(cfg.output_dir) / "config.yaml").is_file()
             assert Path(cfg.output_filepath).is_file()

             feature_columns = get_feature_columns(cfg.output_filepath)
@@ -202,12 +203,7 @@ def test_tabularize():

             # Step 2: Tabularization
             tabularize_static_config = {
-                "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
-                "do_overwrite": False,
-                "seed": 1,
-                "hydra.verbose": True,
-                "tqdm": False,
-                "loguru_init": True,
+                **shared_config,
                 "tabularization.min_code_inclusion_frequency": 1,
                 "tabularization.window_sizes": "[30d,365d,full]",
             }
@@ -218,8 +214,11 @@ def test_tabularize():
             overrides = [f"{k}={v}" for k, v in tabularize_static_config.items()]
             cfg = compose(config_name="tabularization", overrides=overrides)  # config.yaml
             tabularize_static.main(cfg)
-            output_files = list(Path(cfg.output_dir).glob("**/static/**/*.npz"))
-            actual_files = [get_shard_prefix(Path(cfg.output_dir), each) + ".npz" for each in output_files]
+
+            output_dir = Path(cfg.output_cohort_dir) / "tabularize"
+
+            output_files = list(output_dir.glob("**/static/**/*.npz"))
+            actual_files = [get_shard_prefix(output_dir, each) + ".npz" for each in output_files]
             assert set(actual_files) == set(EXPECTED_STATIC_FILES)
             # Check the files are not empty
             for f in output_files:
@@ -252,9 +251,9 @@ def test_tabularize():
             tabularize_time_series.main(cfg)

             # confirm summary files exist:
-            output_files = list_subdir_files(cfg.output_dir, "npz")
+            output_files = list_subdir_files(str(output_dir.resolve()), "npz")
             actual_files = [
-                get_shard_prefix(Path(cfg.output_dir), each) + ".npz"
+                get_shard_prefix(output_dir, each) + ".npz"
                 for each in output_files
                 if "none/static" not in str(each)
             ]
@@ -282,12 +281,7 @@ def test_tabularize():

             # Step 3: Cache Task data
             cache_config = {
-                "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
-                "do_overwrite": False,
-                "seed": 1,
-                "hydra.verbose": True,
-                "tqdm": False,
-                "loguru_init": True,
+                **shared_config,
                 "tabularization.min_code_inclusion_frequency": 1,
                 "tabularization.window_sizes": "[30d,365d,full]",
             }
@@ -299,31 +293,26 @@ def test_tabularize():
             cfg = compose(config_name="task_specific_caching", overrides=overrides)  # config.yaml

             # Create fake labels
-            for f in list_subdir_files(Path(cfg.MEDS_cohort_dir) / "final_cohort", "parquet"):
+            for f in list_subdir_files(Path(cfg.MEDS_cohort_dir) / "data", "parquet"):
                 df = pl.scan_parquet(f)
                 df = get_unique_time_events_df(get_events_df(df, feature_columns)).collect()
                 pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]])
                 df = df.with_columns(pl.Series(name="label", values=pseudo_labels))
-                df = df.select(pl.col(["patient_id", "timestamp", "label"]))
+                df = df.select(pl.col(["patient_id", "time", "label"]))
                 df = df.with_row_index("event_id")

                 split = f.parent.stem
                 shard_num = f.stem
                 out_f = Path(cfg.input_label_dir) / Path(
-                    get_shard_prefix(Path(cfg.MEDS_cohort_dir) / "final_cohort", f)
+                    get_shard_prefix(Path(cfg.MEDS_cohort_dir) / "data", f)
                 ).with_suffix(".parquet")
                 out_f.parent.mkdir(parents=True, exist_ok=True)
                 df.write_parquet(out_f)

             cache_task.main(cfg)

             xgboost_config_kwargs = {
-                "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
-                "do_overwrite": False,
-                "seed": 1,
-                "hydra.verbose": True,
-                "tqdm": False,
-                "loguru_init": True,
+                **shared_config,
                 "tabularization.min_code_inclusion_frequency": 1,
                 "tabularization.window_sizes": "[30d,365d,full]",
             }
@@ -334,8 +323,10 @@ def test_tabularize():
             overrides = [f"{k}={v}" for k, v in xgboost_config_kwargs.items()]
             cfg = compose(config_name="launch_xgboost", overrides=overrides)  # config.yaml

+            output_dir = Path(cfg.output_cohort_dir) / "model"
+
             launch_xgboost.main(cfg)
-            output_files = list(Path(cfg.output_dir).glob("**/*.json"))
+            output_files = list(output_dir.glob("**/*.json"))
             assert len(output_files) == 1


@@ -355,6 +346,7 @@ def test_xgboost_config():
     stderr, stdout_agg = run_command("generate-subsets", ["[static/present]"], {}, "generate-subsets aggs")
     xgboost_config_kwargs = {
         "MEDS_cohort_dir": MEDS_cohort_dir,
+        "output_cohort_dir": "blah",
         "do_overwrite": False,
         "seed": 1,
         "hydra.verbose": True,
@@ -369,5 +361,4 @@ def test_xgboost_config():
     ):  # path to config.yaml
         overrides = [f"{k}={v}" for k, v in xgboost_config_kwargs.items()]
         cfg = compose(config_name="launch_xgboost", overrides=overrides)  # config.yaml
-    print(cfg.tabularization.window_sizes)
     assert cfg.tabularization.window_sizes