fixed bug with sparse matrix shape being too small for merging static and time series dataframes
Oufattole committed May 31, 2024
1 parent 7668382 commit b6b8d43
Showing 2 changed files with 70 additions and 33 deletions.
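In brief: the static features were packed into a csr_matrix whose row count came from the static dataframe alone, which is too small whenever the time-series dataframe covers more patients. A minimal sketch of the failure mode and the fix, using invented toy ids (this is not the repository's code):

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

# Toy data: the time-series frame covers 3 patients, but only 2 have static rows.
static_df = pd.DataFrame({"patient_id": [0, 1], "feat": [1.0, 2.0]})
ts_df = pd.DataFrame({"patient_id": [0, 0, 1, 2, 2]})

static_values = static_df.drop(columns="patient_id").values  # shape (2, 1)

# Buggy shape: (static_values.shape[0], ...) == (2, 1) has no row for patient 2,
# so duplicating rows to align with ts_df later indexes past the end of the matrix.
# Fixed shape: one row per patient seen in either dataframe.
num_patients = max(static_df.patient_id.nunique(), ts_df.patient_id.nunique())  # 3

rows, cols = np.nonzero(static_values)
static_matrix = csr_matrix(
    (static_values[rows, cols], (rows, cols)),
    shape=(num_patients, static_values.shape[1]),  # (3, 1); patient 2 is all zeros
)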
14 changes: 8 additions & 6 deletions scripts/tabularize_merge.py
@@ -27,12 +27,15 @@ def merge_dfs(feature_columns, static_df, ts_df):
    Returns:
    - pd.DataFrame: A merged dataframe containing static and time-series features.
    """
    # TODO - store static and ts data as numpy matrices
    # TODO - Eventually do this duplication at the task specific stage after filtering patients and features
    # Make static data sparse and merge it with the time-series data
    logger.info("Make static data sparse and merge it with the time-series data")
    assert static_df.patient_id.is_monotonic_increasing
    assert ts_df.patient_id.is_monotonic_increasing
    sparse_time_series = ts_df.drop(columns=["patient_id", "timestamp"]).sparse.to_coo()
    duplication_index = ts_df["patient_id"].value_counts().sort_index()

    num_patients = max(static_df.patient_id.nunique(), ts_df.patient_id.nunique())

    # load static data as sparse matrix
    static_matrix = static_df.drop(columns="patient_id").values
@@ -46,9 +49,8 @@ def merge_dfs(feature_columns, static_df, ts_df):
        data_list.append(data)
        rows.append(row)
        cols.append(col)
    static_matrix = csr_matrix(
        (data_list, (rows, cols)), shape=(static_matrix.shape[0], static_matrix.shape[1])
    )
    static_matrix = csr_matrix((data_list, (rows, cols)), shape=(num_patients, static_matrix.shape[1]))
    # Duplicate static matrix rows to match time-series data
    duplication_index = ts_df["patient_id"].value_counts().sort_index().reset_index(drop=True)
    reindex_slices = np.repeat(duplication_index.index.values, duplication_index.values)
    static_matrix = static_matrix[reindex_slices, :]
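After the shape fix, the duplication step expands each patient's single static row into one copy per time-series row. A short sketch of the np.repeat indexing, under the sortedness that the asserts above enforce (toy counts invented for illustration):

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

ts_df = pd.DataFrame({"patient_id": [0, 0, 1, 2, 2]})  # 2, 1, and 2 rows per patient
static_matrix = csr_matrix(np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]]))  # one row per patient

# value_counts().sort_index() orders the counts by patient_id; reset_index(drop=True)
# maps them onto positional rows 0..n-1 of the static matrix.
duplication_index = ts_df["patient_id"].value_counts().sort_index().reset_index(drop=True)
reindex_slices = np.repeat(duplication_index.index.values, duplication_index.values)
# reindex_slices == [0, 0, 1, 2, 2]

expanded = static_matrix[reindex_slices, :]  # shape (5, 2): one static row per ts row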
@@ -81,7 +83,7 @@ def merge_dfs(feature_columns, static_df, ts_df):


@hydra.main(version_base=None, config_path="../configs", config_name="tabularize")
def tabularize_ts_data(
def merge_data(
    cfg: DictConfig,
):
"""Processes a medical dataset to generates and stores flat representatiosn of time-series data.
@@ -146,4 +148,4 @@ def write_fn(data, out_df):


if __name__ == "__main__":
    tabularize_ts_data()
    merge_data()
89 changes: 62 additions & 27 deletions tests/test_tabularize.py
@@ -15,6 +15,7 @@

from scripts.identify_columns import store_columns
from scripts.summarize_over_windows import summarize_ts_data_over_windows
from scripts.tabularize_merge import merge_data
from scripts.tabularize_static import tabularize_static_data
from scripts.tabularize_ts import tabularize_ts_data

@@ -102,6 +103,60 @@
"tuning/0": MEDS_TUNING_0,
}

SUMMARIZE_EXPECTED_FILES = [
    "train/365d/value/sum/0.pkl",
    "train/365d/value/sum/1.pkl",
    "train/365d/code/count/0.pkl",
    "train/365d/code/count/1.pkl",
    "train/full/value/sum/0.pkl",
    "train/full/value/sum/1.pkl",
    "train/full/code/count/0.pkl",
    "train/full/code/count/1.pkl",
    "train/30d/value/sum/0.pkl",
    "train/30d/value/sum/1.pkl",
    "train/30d/code/count/0.pkl",
    "train/30d/code/count/1.pkl",
    "held_out/365d/value/sum/0.pkl",
    "held_out/365d/code/count/0.pkl",
    "held_out/full/value/sum/0.pkl",
    "held_out/full/code/count/0.pkl",
    "held_out/30d/value/sum/0.pkl",
    "held_out/30d/code/count/0.pkl",
    "tuning/365d/value/sum/0.pkl",
    "tuning/365d/code/count/0.pkl",
    "tuning/full/value/sum/0.pkl",
    "tuning/full/code/count/0.pkl",
    "tuning/30d/value/sum/0.pkl",
    "tuning/30d/code/count/0.pkl",
]

MERGE_EXPECTED_FILES = [
    "train/365d/value/sum/0.npy",
    "train/365d/value/sum/1.npy",
    "train/365d/code/count/0.npy",
    "train/365d/code/count/1.npy",
    "train/full/value/sum/0.npy",
    "train/full/value/sum/1.npy",
    "train/full/code/count/0.npy",
    "train/full/code/count/1.npy",
    "train/30d/value/sum/0.npy",
    "train/30d/value/sum/1.npy",
    "train/30d/code/count/0.npy",
    "train/30d/code/count/1.npy",
    "held_out/365d/value/sum/0.npy",
    "held_out/365d/code/count/0.npy",
    "held_out/full/value/sum/0.npy",
    "held_out/full/code/count/0.npy",
    "held_out/30d/value/sum/0.npy",
    "held_out/30d/code/count/0.npy",
    "tuning/365d/value/sum/0.npy",
    "tuning/365d/code/count/0.npy",
    "tuning/full/value/sum/0.npy",
    "tuning/full/code/count/0.npy",
    "tuning/30d/value/sum/0.npy",
    "tuning/30d/code/count/0.npy",
]


def test_tabularize():
    with tempfile.TemporaryDirectory() as d:
@@ -177,33 +232,13 @@ def test_tabularize():
        # confirm summary files exist:
        output_files = list(tabularized_data_dir.glob("ts/*/*/*/*/*.pkl"))
        actual_files = [str(Path(*f.parts[-5:])) for f in output_files]
        expected_files = [
            "train/365d/value/sum/0.pkl",
            "train/365d/value/sum/1.pkl",
            "train/365d/code/count/0.pkl",
            "train/365d/code/count/1.pkl",
            "train/full/value/sum/0.pkl",
            "train/full/value/sum/1.pkl",
            "train/full/code/count/0.pkl",
            "train/full/code/count/1.pkl",
            "train/30d/value/sum/0.pkl",
            "train/30d/value/sum/1.pkl",
            "train/30d/code/count/0.pkl",
            "train/30d/code/count/1.pkl",
            "held_out/365d/value/sum/0.pkl",
            "held_out/365d/code/count/0.pkl",
            "held_out/full/value/sum/0.pkl",
            "held_out/full/code/count/0.pkl",
            "held_out/30d/value/sum/0.pkl",
            "held_out/30d/code/count/0.pkl",
            "tuning/365d/value/sum/0.pkl",
            "tuning/365d/code/count/0.pkl",
            "tuning/full/value/sum/0.pkl",
            "tuning/full/code/count/0.pkl",
            "tuning/30d/value/sum/0.pkl",
            "tuning/30d/code/count/0.pkl",
        ]
        assert set(actual_files) == set(expected_files)

        assert set(actual_files) == set(SUMMARIZE_EXPECTED_FILES)
        for f in output_files:
            df = pd.read_pickle(f)
            assert df.shape[0] > 0

        merge_data(cfg)
        output_files = list(tabularized_data_dir.glob("sparse/*/*/*/*/*.npy"))
        actual_files = [str(Path(*f.parts[-5:])) for f in output_files]
        assert set(actual_files) == set(MERGE_EXPECTED_FILES)
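The new merge check only asserts that the expected .npy files exist. A possible follow-up content check, mirroring the pickle check above, assuming each file holds a pickled scipy.sparse matrix written with np.save (the serialization format is an assumption; it is not shown in this diff):

import numpy as np

for f in output_files:
    # np.load on a pickled object returns a 0-d object array; .item() unwraps it.
    sparse_matrix = np.load(f, allow_pickle=True).item()
    assert sparse_matrix.shape[0] > 0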
