Skip to content

Commit

Permalink
Fixed a bug where the summarization script crashed for min and max value aggregations because they return a COO matrix rather than a dense matrix, as the sum and count operations do.
Browse files Browse the repository at this point in the history
  • Loading branch information
Nassim Oufattole committed Jun 2, 2024
1 parent c225c47 commit cb21821
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src/MEDS_tabular_automl/generate_summarized_reps.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,17 @@ def aggregate_matrix(windows, matrix, agg, num_features, use_tqdm=False):
max_index = window["max_index"]
subset_matrix = matrix[min_index : max_index + 1, :]
agg_matrix = sparse_aggregate(subset_matrix, agg).astype(out_dtype)
nozero_ind = np.nonzero(agg_matrix)[0]
col.append(nozero_ind)
data.append(agg_matrix[nozero_ind])
row.append(np.repeat(np.array(i, dtype=np.int32), len(nozero_ind)))
if isinstance(agg_matrix, np.ndarray):
nozero_ind = np.nonzero(agg_matrix)[0]
col.append(nozero_ind)
data.append(agg_matrix[nozero_ind])
row.append(np.repeat(np.array(i, dtype=np.int32), len(nozero_ind)))
elif isinstance(agg_matrix, coo_array):
col.append(agg_matrix.col)
data.append(agg_matrix.data)
row.append(agg_matrix.row)
else:
raise TypeError(f"Invalid matrix type {type(agg_matrix)}")
row = np.concatenate(row)
out_matrix = coo_array(
(np.concatenate(data), (row, np.concatenate(col))),
Expand Down Expand Up @@ -355,16 +362,16 @@ def generate_summary(
df = pl.scan_parquet(
Path("/storage/shared/meds_tabular_ml/ebcl_dataset/processed")
/ "final_cohort"
/ "held_out"
/ "7.parquet"
/ "train"
/ "3.parquet"
)
agg = "code/count"
agg = "value/min"
index_df, sparse_matrix = get_flat_ts_rep(agg, feature_columns, df)
generate_summary(
feature_columns=feature_columns,
index_df=index_df,
matrix=sparse_matrix,
window_size="1d",
window_size="30d",
agg=agg,
use_tqdm=True,
)

0 comments on commit cb21821

Please sign in to comment.