diff --git a/src/MEDS_tabular_automl/generate_summarized_reps.py b/src/MEDS_tabular_automl/generate_summarized_reps.py index e8d3e42..092ef6d 100644 --- a/src/MEDS_tabular_automl/generate_summarized_reps.py +++ b/src/MEDS_tabular_automl/generate_summarized_reps.py @@ -40,30 +40,59 @@ def f(c: str) -> str: return f -# def sparse_groupby_sum(df): -# id_cols = ["patient_id", "timestamp"] -# ohe = OneHotEncoder(sparse_output=True) -# # Get all other columns we are not grouping by -# other_columns = [col for col in df.columns if col not in id_cols] -# # Get a 607875 x nDistinctIDs matrix in sparse row format with exactly -# # 1 nonzero entry per row -# onehot = ohe.fit_transform(df[id_cols].values.reshape(-1, 1)) -# # Transpose it. then convert from sparse column back to sparse row, as -# # dot products of two sparse row matrices are faster than sparse col with -# # sparse row -# onehot = onehot.T.tocsr() -# # Dot the transposed matrix with the other columns of the df, converted to sparse row -# # format, then convert the resulting matrix back into a sparse -# # dataframe with the same column names -# out = pd.DataFrame.sparse.from_spmatrix( -# onehot.dot(df[other_columns].sparse.to_coo().tocsr()), -# columns=other_columns) -# # Add in the groupby column to this resulting dataframe with the proper class labels -# out[groupby] = ohe.categories_[0] -# # This final groupby sum simply ensures the result is in the format you would expect -# # for a regular pandas groupby and sum, but you can just return out if this is going to be -# # a performance penalty. Note in that case that the groupby column may have changed index -# return out.groupby(groupby).sum() +def sparse_aggregate(sparse_matrix, agg): + if agg == "sum": + merged_matrix = sparse_matrix.sum(axis=0) + elif agg == "min": + merged_matrix = sparse_matrix.min(axis=0) + elif agg == "max": + merged_matrix = sparse_matrix.max(axis=0) + elif agg == "sum_sqd": + merged_matrix = sparse_matrix.power(2).sum(axis=0) + elif agg == "count": + merged_matrix = sparse_matrix.getnnz(axis=0) + else: + raise ValueError(f"Aggregation method '{agg}' not implemented.") + return csr_matrix(merged_matrix) + + +def sum_merge_timestamps(df, sparse_matrix, agg): + """Groups by timestamp and combines rows that are on the same date. + + The combining is done by summing the rows in the sparse matrix that correspond to the same date. + + Args: + df (DataFrame): The DataFrame with 'timestamp' and 'patient_id'. + sparse_matrix (csr_matrix): The corresponding sparse matrix with data. + agg (str): Aggregation method, currently only 'sum' is implemented. + + Returns: + DataFrame, csr_matrix: Tuple containing the DataFrame with aggregated timestamps and the corresponding + sparse matrix. + """ + # Assuming 'timestamp' is already sorted; if not, uncomment the next line: + # df = df.sort_values(by='timestamp') + + # Group by timestamp and sum the data + grouped = df.groupby("timestamp") + indices = grouped.indices + + # Create a new sparse matrix with summed rows per unique timestamp + patient_id = df["patient_id"].iloc[0] + timestamps = [] + output_matrix = csr_matrix((0, sparse_matrix.shape[1]), dtype=sparse_matrix.dtype) + + # Loop through each group and sum + for timestamp, rows in indices.items(): + # Combine the rows in the sparse matrix for the current group (respecting the aggregation being used) + merged_matrix = sparse_aggregate(sparse_matrix[rows], agg) + # Save the non-zero elements + output_matrix = vstack([output_matrix, merged_matrix]) + timestamps.extend([timestamp]) + + # Create output DataFrame + out_df = pd.DataFrame({"patient_id": [patient_id] * len(timestamps), "timestamp": timestamps}) + return out_df, output_matrix def sparse_rolling(df, sparse_matrix, timedelta, agg): @@ -86,18 +115,12 @@ def sparse_rolling(df, sparse_matrix, timedelta, agg): patient_id = df.iloc[0].patient_id df = df.drop(columns="patient_id").reset_index(drop=True).reset_index() timestamps = [] - logger.info("rolling for patient_id") out_sparse_matrix = coo_matrix((0, sparse_matrix.shape[1]), dtype=sparse_matrix.dtype) for each in df[["index", "timestamp"]].rolling(on="timestamp", window=timedelta): - subset_matrix = sparse_matrix[each["index"]] - - # TODO this is where we would apply the aggregation timestamps.append(each.index.max()) - agg_subset_matrix = subset_matrix.sum(axis=0) + agg_subset_matrix = sparse_aggregate(sparse_matrix[each["index"]], agg) out_sparse_matrix = vstack([out_sparse_matrix, agg_subset_matrix]) out_df = pd.DataFrame({"patient_id": [patient_id] * len(timestamps), "timestamp": timestamps}) - # out_df = pd.concat([out_df, pd.DataFrame.sparse.from_spmatrix(out_sparse_matrix)], axis=1) - # out_df.columns = df.columns[1:] return out_df, out_sparse_matrix @@ -168,14 +191,12 @@ def compute_agg(df, window_size: str, agg: str): for patient_id, subset_df in tqdm(group.items(), total=len(group)): logger.info("sparse rolling setup") subset_sparse_matrix = sparse_matrix[subset_df.index] - patient_df = subset_df[ - ["patient_id", "timestamp"] - ] # pd.concat([subset_df[["patient_id", "timestamp"]], sparse_df], axis=1) + patient_df = subset_df[["patient_id", "timestamp"]] assert patient_df.timestamp.isnull().sum() == 0, "timestamp cannot be null" logger.info("sparse rolling start") + patient_df, subset_sparse_matrix = sum_merge_timestamps(patient_df, subset_sparse_matrix, agg) patient_df, out_sparse = sparse_rolling(patient_df, subset_sparse_matrix, timedelta, agg) logger.info("sparse rolling complete") - # patient_df["patient_id"] = patient_id out_dfs.append(patient_df) out_sparse_matrix = vstack([out_sparse_matrix, out_sparse]) out_df = pd.concat(out_dfs, axis=0)