
Commit 2acc3bc

Improved speed by removing the conversion from sparse scipy matrix to sparse pandas array for each patient; now we just use sparse scipy matrices.
Oufattole committed May 29, 2024
1 parent f125600 commit 2acc3bc
Showing 1 changed file with 22 additions and 15 deletions.
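
The gist of the change: the old code converted each patient's rows from a scipy sparse matrix into a pandas sparse DataFrame (and back) before aggregating, while the new code slices and sums the scipy CSR matrix directly. A minimal sketch of that difference, using toy shapes and a made-up patient_rows slice rather than this file's actual data:

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from scipy.sparse import random as sparse_random

# Toy feature matrix: 1,000 event rows x 50 feature columns, ~1% non-zero.
sparse_matrix = sparse_random(1_000, 50, density=0.01, format="csr")
patient_rows = np.arange(0, 100)  # hypothetical row positions for one patient

# Old path: round-trip the patient's slice through a pandas sparse DataFrame.
patient_df = pd.DataFrame.sparse.from_spmatrix(sparse_matrix[patient_rows])
old_sum = csr_matrix(patient_df.sparse.to_coo()).sum(axis=0)

# New path: slice and aggregate the scipy matrix directly, no pandas round-trip.
new_sum = sparse_matrix[patient_rows].sum(axis=0)

assert np.allclose(old_sum, new_sum)  # same result, without the per-patient conversion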
src/MEDS_tabular_automl/generate_summarized_reps.py (37 changes: 22 additions & 15 deletions)
@@ -66,7 +66,7 @@ def f(c: str) -> str:
     # return out.groupby(groupby).sum()


-def sparse_rolling(df, timedelta, agg):
+def sparse_rolling(df, sparse_matrix, timedelta, agg):
     """Iterates through rolling windows while maintaining sparsity.

     Example:
Expand All @@ -83,22 +83,22 @@ def sparse_rolling(df, timedelta, agg):
timestamp datetime64[ns]
dtype: object
"""
patient_id = df.iloc[0].patient_id
df = df.drop(columns="patient_id").reset_index(drop=True).reset_index()
timestamps = []
logger.info("rolling for patient_id")
sparse_matrix = csr_matrix(df[df.columns[2:]].sparse.to_coo())
out_sparse_matrix = coo_matrix((0, sparse_matrix.shape[1]), dtype=sparse_matrix.dtype)
for each in tqdm(df[["index", "timestamp"]].rolling(on="timestamp", window=timedelta), total=len(df)):
for each in df[["index", "timestamp"]].rolling(on="timestamp", window=timedelta):
subset_matrix = sparse_matrix[each["index"]]

# TODO this is where we would apply the aggregation
timestamps.append(each.index.max())
agg_subset_matrix = subset_matrix.sum(axis=0)
out_sparse_matrix = vstack([out_sparse_matrix, agg_subset_matrix])
out_df = pd.DataFrame({"timestamp": timestamps})
out_df = pd.concat([out_df, pd.DataFrame.sparse.from_spmatrix(out_sparse_matrix)], axis=1)
out_df.columns = df.columns[1:]
return out_df
out_df = pd.DataFrame({"patient_id": [patient_id] * len(timestamps), "timestamp": timestamps})
# out_df = pd.concat([out_df, pd.DataFrame.sparse.from_spmatrix(out_sparse_matrix)], axis=1)
# out_df.columns = df.columns[1:]
return out_df, out_sparse_matrix


def compute_agg(df, window_size: str, agg: str):
@@ -160,22 +160,29 @@ def compute_agg(df, window_size: str, agg: str):
     sparse_matrix = df[df.columns[2:]].sparse.to_coo()
     sparse_matrix = csr_matrix(sparse_matrix)
     logger.info("done grouping")
+    out_sparse_matrix = coo_matrix((0, sparse_matrix.shape[1]), dtype=sparse_matrix.dtype)
     match agg:
         case "code/count" | "value/sum":
             agg = "sum"
             out_dfs = []
-            for patient_id, subset_df in group.items():
-                logger.info(f"rolling for patient_id {patient_id}")
+            for patient_id, subset_df in tqdm(group.items(), total=len(group)):
+                logger.info("sparse rolling setup")
                 subset_sparse_matrix = sparse_matrix[subset_df.index]
-                sparse_df = pd.DataFrame.sparse.from_spmatrix(subset_sparse_matrix)
-                sparse_df.index = subset_df.index
-                patient_df = pd.concat([subset_df[["patient_id", "timestamp"]], sparse_df], axis=1)
-                patient_df.columns = df.columns
+                patient_df = subset_df[
+                    ["patient_id", "timestamp"]
+                ]  # pd.concat([subset_df[["patient_id", "timestamp"]], sparse_df], axis=1)
                 assert patient_df.timestamp.isnull().sum() == 0, "timestamp cannot be null"
-                patient_df = sparse_rolling(patient_df, timedelta, agg)
-                patient_df["patient_id"] = patient_id
+                logger.info("sparse rolling start")
+                patient_df, out_sparse = sparse_rolling(patient_df, subset_sparse_matrix, timedelta, agg)
+                logger.info("sparse rolling complete")
+                # patient_df["patient_id"] = patient_id
                 out_dfs.append(patient_df)
+                out_sparse_matrix = vstack([out_sparse_matrix, out_sparse])
             out_df = pd.concat(out_dfs, axis=0)
+            out_df = pd.concat(
+                [out_df.reset_index(drop=True), pd.DataFrame.sparse.from_spmatrix(out_sparse_matrix)], axis=1
+            )
+            out_df.columns = df.columns
             out_df.rename(columns=time_aggd_col_alias_fntr(window_size, "count"))

         case _:
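For readers unfamiliar with the rolling pattern above, here is a rough, self-contained sketch of the idea with toy data and column names (not this repository's schema): pandas .rolling(on=..., window=...) is iterated to obtain the row positions of each time window, and those positions slice a scipy CSR matrix whose rows are then summed and stacked.

import pandas as pd
from scipy.sparse import coo_matrix, csr_matrix, vstack

timestamps = pd.to_datetime(["2020-01-01", "2020-01-03", "2020-01-10"])
sparse_matrix = csr_matrix([[1, 0, 0], [0, 2, 0], [0, 0, 3]])  # one toy feature row per event

# reset_index() adds an "index" column holding each row's integer position.
df = pd.DataFrame({"timestamp": timestamps}).reset_index()
out_sparse_matrix = coo_matrix((0, sparse_matrix.shape[1]), dtype=sparse_matrix.dtype)

# Each iteration yields the rows whose timestamps fall in a 7-day window ending at that row.
for window in df[["index", "timestamp"]].rolling(on="timestamp", window=pd.Timedelta(days=7)):
    window_rows = sparse_matrix[window["index"].to_numpy()]
    out_sparse_matrix = vstack([out_sparse_matrix, csr_matrix(window_rows.sum(axis=0))])

print(out_sparse_matrix.toarray())  # one aggregated feature row per timestamp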
