Skip to content

Commit

Permalink
Fixed test error with label re-processing.
Browse files (browse the repository at this point in the history)
  • Loading branch information
mmcdermott committed Aug 11, 2024
1 parent cf2a4e8 commit 9f24efe
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/MEDS_tabular_automl/scripts/cache_task.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -89,7 +89,12 @@ def main(cfg: DictConfig):
np.random.shuffle(tabularization_tasks)

label_dir = Path(cfg.input_label_dir)
label_df = pl.scan_parquet(label_dir / "**/*.parquet").rename({"prediction_time": "time"})
label_df = pl.scan_parquet(label_dir / "**/*.parquet").rename(
{
"prediction_time": "time",
cfg.label_column: "label",
}
)

feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp)

Expand All @@ -105,7 +110,9 @@ def main(cfg: DictConfig):
.with_row_index("event_id")
.select("patient_id", "time", "event_id")
)
shard_label_df = label_df.join_asof(other=raw_data_df, by="patient_id", on="time")
shard_label_df = label_df.join(
raw_data_df.select("patient_id").unique(), on="patient_id", how="inner"
).join_asof(other=raw_data_df, by="patient_id", on="time")

shard_label_fp = Path(cfg.output_label_dir) / split / f"{shard_num}.parquet"
rwlock_wrap(
Expand Down

0 comments on commit 9f24efe

Please sign in to comment.